Skip to content

Commit dd1d8f7

Browse files
committed
fix: Minor changes to quantizer, qmax, sawb
Signed-off-by: Brandon Groth <[email protected]>
1 parent ff5c375 commit dd1d8f7

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

fms_mo/quant_refactor/base_quant.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ def __init__(
7575
RuntimeError: perCh or perGrp was selected without specifying Nch, Ngrp.
7676
RuntimeError: qscheme is not allowed, or could be a typo.
7777
"""
78+
# Init Nch/Ngrp to none incase they won't be set
79+
self.Nch = None
80+
self.Ngrp = None
7881
if isinstance(unit, torch.qscheme):
7982
if "per_channel" in str(unit):
8083
self.q_unit = "perCh"
@@ -124,8 +127,8 @@ def __repr__(self):
124127
"""
125128
q_uint_str = f"qunit={self.q_unit}"
126129
symmetric_str = f", symmetric={self.symmetric}"
127-
Nch_str = f", Nch={self.Nch}",
128-
Ngrp_str = f", Nch={self.Ngrp}",
130+
Nch_str = f", Nch={self.Nch}" if self.Nch is not None else "",
131+
Ngrp_str = f", Ngrp={self.Ngrp}" if self.Ngrp is not None else "",
129132
single_sided_str = f", single_sided={self.single_sided}"
130133
qlevel_lowering_str = f", qlevel_lowering={self.qlevel_lowering}"
131134
return (

fms_mo/quant_refactor/qmax_new.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,10 @@ def __init__(
8989
use_PT_native_Qfunc=kwargs.get("use_PT_native_Qfunc", False),
9090
)
9191

92-
with torch.no_grad():
93-
self.clip_valn.data *= init_clip_valn
94-
self.clip_val.data *= init_clip_val
92+
if not self.training:
93+
with torch.no_grad():
94+
self.clip_valn.data *= init_clip_valn
95+
self.clip_val.data *= init_clip_val
9596

9697
self.align_zero = align_zero
9798
self.clipSTE = clipSTE

fms_mo/quant_refactor/sawb_new.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,10 @@ def __init__(
9292
use_PT_native_Qfunc=kwargs.get("use_PT_native_Qfunc", False),
9393
)
9494

95-
with torch.no_grad():
96-
self.clip_valn.data *= init_clip_valn
97-
self.clip_val.data *= init_clip_val
95+
if not self.training:
96+
with torch.no_grad():
97+
self.clip_valn.data *= init_clip_valn
98+
self.clip_val.data *= init_clip_val
9899

99100
self.clipSTE = clipSTE
100101
self.align_zero = align_zero

0 commit comments

Comments (0)