Skip to content

Commit 4bc34f3

Browse files
committed
feat: Added NperGrp, axis to Qscheme
Signed-off-by: Brandon Groth <[email protected]>
1 parent fe8cd9d commit 4bc34f3

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

fms_mo/quant_refactor/base_quant.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,12 @@ def __init__(
5353
self,
5454
unit: str,
5555
symmetric: bool = True,
56-
Nch: int = None,
57-
Ngrp: int = None,
5856
single_sided: bool = False,
5957
qlevel_lowering: bool = True,
58+
Nch: int = None,
59+
Ngrp: int = None,
60+
NperGrp: int = None,
61+
axis: int = None,
6062
):
6163
"""
6264
Init Qscheme
@@ -69,6 +71,8 @@ def __init__(
6971
single_sided (bool, optional): Specify if clip values are positive. Defaults to False.
7072
qlevel_lowering (bool, optional): Specify lowering of quantized levels.
7173
Defaults to True.
74+
axis (int, optional): Specify which tensor dimension to quantize individually.
75+
Defaults to 0.
7276
7377
Raises:
7478
RuntimeError: New PyTorch qscheme found. Need to update.
@@ -101,17 +105,16 @@ def __init__(
101105
"perCh was selected without specifying Nch."
102106
)
103107
elif unit == "perGrp":
104-
if issubclass(type(Ngrp), int):
108+
if issubclass(type(NperGrp), int):
109+
assert NperGrp > 0, "Provided NperGrp is negative"
110+
self.NperGrp = NperGrp
111+
elif issubclass(type(Ngrp), int):
105112
assert Ngrp > 0, "Provided Ngrp is negative"
106113
self.Ngrp = Ngrp
107114
else:
108115
raise RuntimeError(
109-
"perGrp was selected without specifying Ngrp."
116+
"perGrp was selected without specifying Ngrp or NperGrp."
110117
)
111-
# perGrp can be across channels, but is not required
112-
if issubclass(type(Nch), int):
113-
assert Nch > 0, "Provided Nch is negative"
114-
self.Nch = Nch
115118

116119
self.single_sided = single_sided
117120
self.qlevel_lowering = qlevel_lowering
@@ -128,12 +131,13 @@ def __repr__(self):
128131
q_uint_str = f"qunit={self.q_unit}"
129132
symmetric_str = f", symmetric={self.symmetric}"
130133
Nch_str = f", Nch={self.Nch}" if self.Nch is not None else "",
131-
Ngrp_str = f", Ngrp={self.Ngrp}" if self.Ngrp is not None else "",
134+
Ngrp_str = f", NperGrp={self.NperGrp}" if self.NperGrp else "",
135+
NperGrp_str = f", Ngrp={self.NperGrp}" if self.NperGrp is not None else "",
132136
single_sided_str = f", single_sided={self.single_sided}"
133137
qlevel_lowering_str = f", qlevel_lowering={self.qlevel_lowering}"
134138
return (
135-
f"{self.__class__.__name__}({q_uint_str}{symmetric_str}{Nch_str}{Ngrp_str}"
136-
f"{single_sided_str}{qlevel_lowering_str})"
139+
f"{self.__class__.__name__}({q_uint_str}{symmetric_str}{Nch_str}"
140+
f"{Ngrp_str}{NperGrp_str}{single_sided_str}{qlevel_lowering_str})"
137141
)
138142

139143

0 commit comments

Comments
 (0)