Skip to content

Commit 759dab5

Browse files
committed
fix: Rename refactored _new files/funcs to _rc
Signed-off-by: Brandon Groth <[email protected]>
1 parent aebedb9 commit 759dab5

19 files changed

+284
-286
lines changed

fms_mo/quant_refactor/get_quantizer_new.py renamed to fms_mo/quant_refactor/get_quantizer_rc.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,16 @@
2929
to_fp8,
3030
)
3131
from fms_mo.quant_refactor.base_quant import Qscheme
32-
from fms_mo.quant_refactor.lsq_new import LSQPlus_new, LSQQuantization_new
33-
from fms_mo.quant_refactor.pact2_new import PACT2_new
34-
from fms_mo.quant_refactor.pact2sym_new import PACT2Sym_new
35-
from fms_mo.quant_refactor.pact_new import PACT_new
36-
from fms_mo.quant_refactor.pactplussym_new import PACTplusSym_new
37-
from fms_mo.quant_refactor.qmax_new import Qmax_new
38-
from fms_mo.quant_refactor.sawb_new import SAWB_new
32+
from fms_mo.quant_refactor.lsq_rc import LSQPlus_rc, LSQQuantization_rc
33+
from fms_mo.quant_refactor.pact2_rc import PACT2_rc
34+
from fms_mo.quant_refactor.pact2sym_rc import PACT2Sym_rc
35+
from fms_mo.quant_refactor.pact_rc import PACT_rc
36+
from fms_mo.quant_refactor.pactplussym_rc import PACTplusSym_rc
37+
from fms_mo.quant_refactor.qmax_rc import Qmax_rc
38+
from fms_mo.quant_refactor.sawb_rc import SAWB_rc
3939

4040

41-
def get_activation_quantizer_new(
41+
def get_activation_quantizer_rc(
4242
qa_mode: str = "PACT",
4343
nbits: int = 32,
4444
clip_val: torch.Tensor = None,
@@ -59,10 +59,10 @@ def get_activation_quantizer_new(
5959
"""
6060

6161
QPACTLUT = {
62-
"pact_uni": PACT_new,
63-
"pact_bi": PACT2_new,
64-
"pact+_uni": PACT_new,
65-
"pact+_bi": PACT2_new,
62+
"pact_uni": PACT_rc,
63+
"pact_bi": PACT2_rc,
64+
"pact+_uni": PACT_rc,
65+
"pact+_bi": PACT2_rc,
6666
}
6767
if "pact" in qa_mode and "sym" not in qa_mode:
6868
keyQact = qa_mode + "_uni" if non_neg else qa_mode + "_bi"
@@ -84,7 +84,7 @@ def get_activation_quantizer_new(
8484
use_PT_native_Qfunc=use_PT_native_Qfunc,
8585
)
8686
elif qa_mode == "pactsym":
87-
act_quantizer = PACT2Sym_new(
87+
act_quantizer = PACT2Sym_rc(
8888
num_bits=nbits,
8989
init_clip_val=clip_val,
9090
Qscheme=Qscheme(
@@ -99,7 +99,7 @@ def get_activation_quantizer_new(
9999
use_PT_native_Qfunc=use_PT_native_Qfunc,
100100
)
101101
elif qa_mode == "pactsym+":
102-
act_quantizer = PACTplusSym_new(
102+
act_quantizer = PACTplusSym_rc(
103103
nbits,
104104
init_clip_val=clip_val,
105105
Qscheme=Qscheme(
@@ -115,7 +115,7 @@ def get_activation_quantizer_new(
115115
use_PT_native_Qfunc=use_PT_native_Qfunc,
116116
)
117117
elif qa_mode == "lsq+":
118-
act_quantizer = LSQPlus_new(
118+
act_quantizer = LSQPlus_rc(
119119
nbits,
120120
init_clip_valn=clip_valn,
121121
init_clip_val=clip_val,
@@ -131,12 +131,12 @@ def get_activation_quantizer_new(
131131
# use_PT_native_Qfunc=use_PT_native_Qfunc, # nativePT not enabled for LSQ+
132132
)
133133
elif qa_mode == "lsq":
134-
act_quantizer = LSQQuantization_new(
134+
act_quantizer = LSQQuantization_rc(
135135
nbits, init_clip_val=clip_val, dequantize=True, inplace=False
136136
)
137137
# NOTE: need to be careful using this for activation, particular to 1 sided.
138138
elif qa_mode == "max":
139-
act_quantizer = Qmax_new(
139+
act_quantizer = Qmax_rc(
140140
nbits,
141141
Qscheme=Qscheme(
142142
unit="perT",
@@ -151,7 +151,7 @@ def get_activation_quantizer_new(
151151
use_PT_native_Qfunc=use_PT_native_Qfunc,
152152
)
153153
elif qa_mode == "minmax":
154-
act_quantizer = Qmax_new(
154+
act_quantizer = Qmax_rc(
155155
nbits,
156156
Qscheme=Qscheme(
157157
unit="perT",
@@ -166,7 +166,7 @@ def get_activation_quantizer_new(
166166
use_PT_native_Qfunc=use_PT_native_Qfunc,
167167
)
168168
elif qa_mode == "maxsym":
169-
act_quantizer = Qmax_new(
169+
act_quantizer = Qmax_rc(
170170
nbits,
171171
Qscheme=Qscheme(
172172
unit="perT",
@@ -212,7 +212,7 @@ def get_activation_quantizer_new(
212212
return act_quantizer
213213

214214

215-
def get_weight_quantizer_new(
215+
def get_weight_quantizer_rc(
216216
qw_mode: str = "SAWB+",
217217
nbits: int = 32,
218218
clip_val: torch.Tensor = None,
@@ -240,7 +240,7 @@ def get_weight_quantizer_new(
240240
unit = "perCh" if Nch is not False else "perGrp" if perGp is not None else "perT"
241241
if "sawb" in qw_mode:
242242
clipSTE = "+" in qw_mode
243-
weight_quantizer = SAWB_new(
243+
weight_quantizer = SAWB_rc(
244244
nbits,
245245
Qscheme=Qscheme(
246246
unit=unit,
@@ -255,7 +255,7 @@ def get_weight_quantizer_new(
255255
use_PT_native_Qfunc=use_PT_native_Qfunc,
256256
)
257257
elif "max" in qw_mode:
258-
weight_quantizer = Qmax_new(
258+
weight_quantizer = Qmax_rc(
259259
nbits,
260260
Qscheme=Qscheme(
261261
unit=unit,
@@ -269,7 +269,7 @@ def get_weight_quantizer_new(
269269
use_PT_native_Qfunc=use_PT_native_Qfunc,
270270
)
271271
elif qw_mode == "pact":
272-
weight_quantizer = PACT2_new(
272+
weight_quantizer = PACT2_rc(
273273
nbits,
274274
init_clip_valn=clip_valn,
275275
init_clip_val=clip_val,
@@ -285,7 +285,7 @@ def get_weight_quantizer_new(
285285
use_PT_native_Qfunc=use_PT_native_Qfunc,
286286
)
287287
elif qw_mode == "pact+":
288-
weight_quantizer = PACTplusSym_new(
288+
weight_quantizer = PACTplusSym_rc(
289289
nbits,
290290
init_clip_val=clip_val,
291291
Qscheme=Qscheme(
@@ -300,7 +300,7 @@ def get_weight_quantizer_new(
300300
use_PT_native_Qfunc=use_PT_native_Qfunc,
301301
)
302302
elif qw_mode == "lsq+":
303-
weight_quantizer = LSQPlus_new(
303+
weight_quantizer = LSQPlus_rc(
304304
nbits,
305305
init_clip_valb=clip_valn,
306306
init_clip_vals=clip_val,
Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
)
5151

5252

53-
class LSQQuantization_new(Quantizer):
53+
class LSQQuantization_rc(Quantizer):
5454
"""
5555
LSQ Quantizer
5656
@@ -94,12 +94,12 @@ def __init__(
9494

9595
def set_quantizer(self):
9696
"""
97-
Set quantizer STE - use LSQQuantizationSTE_new
97+
Set quantizer STE - use LSQQuantizationSTE_rc
9898
"""
99-
self.quantizer = LSQQuantizationSTE_new
99+
self.quantizer = LSQQuantizationSTE_rc
100100

101101

102-
class LSQQuantizationSTE_new(PerTensorSTE):
102+
class LSQQuantizationSTE_rc(PerTensorSTE):
103103
"""
104104
1-sided LSQ quantization STE
105105
@@ -192,7 +192,7 @@ def backward(ctx, grad_output):
192192
return grad_input, grad_alpha * grad_scale, None, None, None, None, None
193193

194194

195-
class LSQPlus_new(Quantizer):
195+
class LSQPlus_rc(Quantizer):
196196
"""
197197
LSQ+ Quantizer
198198
@@ -242,9 +242,9 @@ def __init__(
242242

243243
def set_quantizer(self):
244244
"""
245-
Set quantizer STE - use LSQPlus_func_new
245+
Set quantizer STE - use LSQPlus_func_rc
246246
"""
247-
self.quantizer = LSQPlus_func_new
247+
self.quantizer = LSQPlus_func_rc
248248

249249
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
250250
"""
@@ -283,7 +283,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
283283
return output
284284

285285

286-
class LSQPlus_func_new(torch.autograd.Function):
286+
class LSQPlus_func_rc(torch.autograd.Function):
287287
"""2-side LSQ+ from CVPR workshop paper"""
288288

289289
@staticmethod
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
)
3636

3737

38-
class PACT2_new(Quantizer):
38+
class PACT2_rc(Quantizer):
3939
"""
4040
Two-sided original PACT
4141
PACT2 can be used to quantize both weights and activations
@@ -88,10 +88,10 @@ def set_quantizer(self):
8888
if self.use_PT_native_Qfunc:
8989
self.quantizer = PerTensorSTE_PTnative
9090
else:
91-
self.quantizer = PACTplus2STE_new if self.pact_plus else PACT2_STE_new
91+
self.quantizer = PACTplus2STE_rc if self.pact_plus else PACT2_STE_rc
9292

9393

94-
class PACT2_STE_new(PerTensorSTE):
94+
class PACT2_STE_rc(PerTensorSTE):
9595
"""
9696
two-sided original pact quantization for activation
9797
@@ -137,7 +137,7 @@ def backward(ctx, grad_output):
137137
return grad_input, grad_alpha, grad_alphan, None, None, None, None
138138

139139

140-
class PACTplus2STE_new(PerTensorSTE):
140+
class PACTplus2STE_rc(PerTensorSTE):
141141
"""
142142
two-sided pact+ quantization for activation
143143
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
)
3636

3737

38-
class PACT2Sym_new(Quantizer):
38+
class PACT2Sym_rc(Quantizer):
3939
"""
4040
Two-sided PACT with symmetric clip values
4141
@@ -86,10 +86,10 @@ def set_quantizer(self):
8686
if self.use_PT_native_Qfunc:
8787
self.quantizer = PerTensorSTE_PTnative
8888
else:
89-
self.quantizer = PACT2Sym_STE_new
89+
self.quantizer = PACT2Sym_STE_rc
9090

9191

92-
class PACT2Sym_STE_new(PerTensorSTE):
92+
class PACT2Sym_STE_rc(PerTensorSTE):
9393
"""
9494
Symmetric with zero in the center. For example, 4bit -- > [-7, 7] with FP0 align to INT0
9595
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
)
3636

3737

38-
class PACT_new(Quantizer):
38+
class PACT_rc(Quantizer):
3939
"""
4040
1-sided original PACT
4141
PACT is only used to quantize activations
@@ -89,10 +89,10 @@ def set_quantizer(self):
8989
if self.use_PT_native_Qfunc: # PTnative overrides all other options
9090
self.quantizer = PerTensorSTE_PTnative
9191
else:
92-
self.quantizer = PACTplusSTE_new if self.pact_plus else PACT_STE_new
92+
self.quantizer = PACTplusSTE_rc if self.pact_plus else PACT_STE_rc
9393

9494

95-
class PACT_STE_new(PerTensorSTE):
95+
class PACT_STE_rc(PerTensorSTE):
9696
"""
9797
Single-sided PACT STE
9898
@@ -128,7 +128,7 @@ def backward(ctx, grad_output):
128128
return grad_input, grad_alpha, None, None, None, None, None
129129

130130

131-
class PACTplusSTE_new(PerTensorSTE):
131+
class PACTplusSTE_rc(PerTensorSTE):
132132
"""
133133
Single-sided PACT+ STE
134134
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
)
3636

3737

38-
class PACTplusSym_new(Quantizer):
38+
class PACTplusSym_rc(Quantizer):
3939
"""
4040
Two-sided symmetric PACT+
4141
PACTplusSym can be used to quantize both weights and activations
@@ -95,10 +95,10 @@ def set_quantizer(self):
9595
self.quantizer = PerTensorSTE_PTnative
9696
else:
9797
if self.extend_act_range:
98-
self.quantizer = PACTplusExtendRangeSTE_new
98+
self.quantizer = PACTplusExtendRangeSTE_rc
9999
self.quantizer_name = "PACT+extend"
100100
else:
101-
self.quantizer = PACTplusSymSTE_new
101+
self.quantizer = PACTplusSymSTE_rc
102102
self.quantizer_name = "PACT+sym"
103103

104104
def set_extend_act_range(self, extend_act_range: bool):
@@ -111,7 +111,7 @@ def set_extend_act_range(self, extend_act_range: bool):
111111
self.extend_act_range = extend_act_range
112112

113113

114-
class PACTplusSymSTE_new(PerTensorSTE):
114+
class PACTplusSymSTE_rc(PerTensorSTE):
115115
"""
116116
Symmetric 2-sided PACT+
117117
@@ -157,7 +157,7 @@ def backward(ctx, grad_output):
157157
return grad_input, grad_alpha, None, None, None, None, None
158158

159159

160-
class PACTplusExtendRangeSTE_new(torch.autograd.Function):
160+
class PACTplusExtendRangeSTE_rc(torch.autograd.Function):
161161
"""
162162
2-sided PACT+ using a single clip
163163

0 commit comments

Comments
 (0)