Skip to content

Commit d141bcd

Browse files
committed
fix: Updates to Qscheme for perCh
Signed-off-by: Brandon Groth <[email protected]>
1 parent f02a5d6 commit d141bcd

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

fms_mo/quant_refactor/base_quant.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,11 @@ def __init__(
7979
RuntimeError: perCh or perGrp was selected without specifying Nch, Ngrp.
8080
RuntimeError: qscheme is not allowed, or could be a typo.
8181
"""
82-
# Init Nch/Ngrp to none incase they won't be set
82+
# Init perCh/perGrp vars to none incase they won't be set
8383
self.Nch = None
8484
self.Ngrp = None
85+
self.NperGrp = None
86+
self.axis = None
8587
if isinstance(unit, torch.qscheme):
8688
if "per_channel" in str(unit):
8789
self.q_unit = "perCh"
@@ -97,13 +99,19 @@ def __init__(
9799
self.q_unit = unit
98100
self.symmetric = symmetric
99101
if unit == "perCh":
100-
if issubclass(type(Nch), int):
102+
if Nch is not None and issubclass(type(Nch), int):
101103
assert Nch > 0, "Provided Nch is negative"
102104
self.Nch = Nch
103105
else:
104106
raise RuntimeError(
105107
"perCh was selected without specifying Nch."
106108
)
109+
if axis is not None and issubclass(type(axis), int):
110+
self.axis = axis
111+
else:
112+
raise RuntimeError(
113+
"perCh was selected without specifying channel axis dimension."
114+
)
107115
elif unit == "perGrp":
108116
if issubclass(type(NperGrp), int):
109117
assert NperGrp > 0, "Provided NperGrp is negative"
@@ -130,9 +138,9 @@ def __repr__(self):
130138
"""
131139
q_uint_str = f"qunit={self.q_unit}"
132140
symmetric_str = f", symmetric={self.symmetric}"
133-
Nch_str = f", Nch={self.Nch}" if self.Nch is not None else "",
134-
Ngrp_str = f", NperGrp={self.NperGrp}" if self.NperGrp else "",
135-
NperGrp_str = f", Ngrp={self.NperGrp}" if self.NperGrp is not None else "",
141+
Nch_str = f", Nch={self.Nch}" if self.Nch is not None else ""
142+
Ngrp_str = f", NperGrp={self.Ngrp}" if self.Ngrp is not None else ""
143+
NperGrp_str = f", Ngrp={self.NperGrp}" if self.NperGrp is not None else ""
136144
single_sided_str = f", single_sided={self.single_sided}"
137145
qlevel_lowering_str = f", qlevel_lowering={self.qlevel_lowering}"
138146
return (

0 commit comments

Comments
 (0)