
Commit 7c87e01

chore: formatted quant_refactor

Signed-off-by: Brandon Groth <[email protected]>
1 parent 99fe7c7 commit 7c87e01

21 files changed: +260 −431 lines

fms_mo/quant/quantizers.py

10 additions & 4 deletions

@@ -510,7 +510,9 @@ def forward(
 
         if istraining:
             # only recalc clipvals under training mode
-            num_bits_int = num_bits.item() if isinstance(num_bits, torch.Tensor) else num_bits
+            num_bits_int = (
+                num_bits.item() if isinstance(num_bits, torch.Tensor) else num_bits
+            )
             SAWBcode_mapping = {8: 803, 4: 403, 2: 103}
             if num_bits in [2, 4, 8]:
                 sawb_code = SAWBcode_mapping[num_bits_int]
@@ -550,9 +552,13 @@ def forward(
                     clip_val.dtype
                 )  # NOTE return will be a fp32 tensor; function only support float()
             else:
-                output = torch.quantize_per_channel(
-                    input_tensor, scale, zero_point, 0, torch.qint8
-                ).int_repr().clamp(int_l, int_u)
+                output = (
+                    torch.quantize_per_channel(
+                        input_tensor, scale, zero_point, 0, torch.qint8
+                    )
+                    .int_repr()
+                    .clamp(int_l, int_u)
+                )
                 # NOTE return will be a torch.int8 tensor
 
         return output
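
For reference, a minimal standalone sketch of the per-channel path this hunk re-wraps, assuming fp32 input, per-channel scale/zero-point on dim 0, and qint8 bounds; the tensor values are illustrative, not from the repo:

import torch

input_tensor = torch.randn(3, 4)                # [channels, features]
scale = torch.tensor([0.1, 0.2, 0.3])           # one scale per channel (dim 0)
zero_point = torch.zeros(3, dtype=torch.int64)  # symmetric case: zero offsets
int_l, int_u = -128, 127                        # torch.qint8 integer bounds

# Quantize per channel, strip the quantized wrapper to raw int8, then clamp
# to the integer bounds, as the reformatted chain above does.
output = (
    torch.quantize_per_channel(input_tensor, scale, zero_point, 0, torch.qint8)
    .int_repr()
    .clamp(int_l, int_u)
)
print(output.dtype)  # torch.int8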

fms_mo/quant_refactor/base_quant.py

4 additions & 4 deletions

@@ -103,9 +103,7 @@ def __init__(
             assert Nch > 0, "Provided Nch is negative"
             self.Nch = Nch
         else:
-            raise RuntimeError(
-                "perCh was selected without specifying Nch."
-            )
+            raise RuntimeError("perCh was selected without specifying Nch.")
         if axis is not None and issubclass(type(axis), int):
             self.axis = axis
         else:
@@ -220,7 +218,9 @@ def __init__(
         self.align_zero = align_zero
         self.clipSTE = clipSTE
 
-        temp_clipvals = torch.ones(self.qscheme.Nch) if self.perCh else torch.Tensor([1.0])
+        temp_clipvals = (
+            torch.ones(self.qscheme.Nch) if self.perCh else torch.Tensor([1.0])
+        )
         self.register_parameter("clip_val", torch.nn.Parameter(temp_clipvals.clone()))
         # Keep clip_valn as positive 1.0 to allow simpler multiplication with
         # negative numbers (clip_valn.data *= clip_valn)
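
A minimal sketch of what the reformatted clip-value setup does, assuming a bare torch.nn.Module carrying perCh/Nch; DummyQuantizer below is a hypothetical stand-in for the Quantizer base class:

import torch

class DummyQuantizer(torch.nn.Module):
    def __init__(self, perCh: bool = True, Nch: int = 4):
        super().__init__()
        self.perCh = perCh
        # One learnable clip value per channel when perCh, else a single scalar.
        temp_clipvals = torch.ones(Nch) if self.perCh else torch.Tensor([1.0])
        self.register_parameter("clip_val", torch.nn.Parameter(temp_clipvals.clone()))

q = DummyQuantizer(perCh=True, Nch=4)
print(q.clip_val.shape)  # torch.Size([4])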

fms_mo/quant_refactor/get_quantizer_new.py

21 additions & 27 deletions

@@ -16,6 +16,7 @@
 Functions to create quantizers for activation and weights. Called from Qmodule level.
 """
 
+# Third Party
 import torch
 
 # Local
@@ -38,15 +39,15 @@
 
 
 def get_activation_quantizer_new(
-    qa_mode:str="PACT",
-    nbits:int=32,
-    clip_val:torch.FloatTensor=None,
-    clip_valn:torch.FloatTensor=None,
-    non_neg:bool=False,
-    align_zero:bool=True, # pylint: disable=unused-argument
-    extend_act_range:bool=False,
-    use_PT_native_Qfunc:bool=False,
-    use_subnormal:bool=False,
+    qa_mode: str = "PACT",
+    nbits: int = 32,
+    clip_val: torch.FloatTensor = None,
+    clip_valn: torch.FloatTensor = None,
+    non_neg: bool = False,
+    align_zero: bool = True,  # pylint: disable=unused-argument
+    extend_act_range: bool = False,
+    use_PT_native_Qfunc: bool = False,
+    use_subnormal: bool = False,
 ):
     """Return a quantizer for activation quantization
     Regular quantizers:
@@ -212,16 +213,16 @@ def get_activation_quantizer_new(
 
 
 def get_weight_quantizer_new(
-    qw_mode:str="SAWB+",
-    nbits:int=32,
-    clip_val:torch.FloatTensor=None,
-    clip_valn:torch.FloatTensor=None,
-    align_zero:bool=True,
-    w_shape:torch.Size=None,
-    recompute:bool=False, # pylint: disable=unused-argument
-    perGp:int=None,
-    use_PT_native_Qfunc:bool=False,
-    use_subnormal:bool=False,
+    qw_mode: str = "SAWB+",
+    nbits: int = 32,
+    clip_val: torch.FloatTensor = None,
+    clip_valn: torch.FloatTensor = None,
+    align_zero: bool = True,
+    w_shape: torch.Size = None,
+    recompute: bool = False,  # pylint: disable=unused-argument
+    perGp: int = None,
+    use_PT_native_Qfunc: bool = False,
+    use_subnormal: bool = False,
 ):
     """Return a quantizer for weight quantization
     Regular quantizers:
@@ -236,13 +237,7 @@ def get_weight_quantizer_new(
     Ngrp = (
         [w_shape[0] * w_shape[1] // perGp, perGp] if "perGp" in qw_mode else False
    )  # store clip_val size and group size
-    unit = (
-        "perCh"
-        if Nch is not False
-        else "perGrp"
-        if perGp is not None
-        else "perT"
-    )
+    unit = "perCh" if Nch is not False else "perGrp" if perGp is not None else "perT"
     if "sawb" in qw_mode:
         clipSTE = "+" in qw_mode
         weight_quantizer = SAWB_new(
@@ -260,7 +255,6 @@ def get_weight_quantizer_new(
             use_PT_native_Qfunc=use_PT_native_Qfunc,
         )
     elif "max" in qw_mode:
-
         weight_quantizer = Qmax_new(
             nbits,
             Qscheme=Qscheme(

fms_mo/quant_refactor/linear_utils.py

10 additions & 8 deletions

@@ -13,12 +13,13 @@
 # limitations under the License.
 
 """
-Linear Quantization Utility functions 
+Linear Quantization Utility functions
 
 Raises:
     ValueError: Lower clip value is less than 0 for symmetric quantization
 """
 
+# Standard
 from typing import Tuple
 
 # Third Party
@@ -288,7 +289,7 @@ def asymmetric_linear_quantization_params(
     diff = sat_max - sat_min
     # If float values are all 0, we just want the quantized values to be 0 as well.
     # So overriding the saturation value to 'n', so the scale becomes 1
-    diff[ diff == 0.0 ] = n_levels
+    diff[diff == 0.0] = n_levels
     scale = diff / n_levels
     zero_point = -sat_min / scale
     if integral_zero_point:
@@ -310,7 +311,7 @@ def symmetric_linear_quantization_params(
        num_bits (torch.IntTensor): Number of bits for quantization.
        sat_max (torch.FloatTensor): Upper clip value. Can be multi-valued (perCh/perGp).
        qlevel_lowering (bool, optional): Specify lowering of quantized levels. Defaults to False.
-        Ngp_or_ch (int, optional): 
+        Ngp_or_ch (int, optional):
 
     Returns:
         [torch.IntTensor, torch.FloatTensor, torch.FloatTensor]:
@@ -326,11 +327,12 @@ def symmetric_linear_quantization_params(
     # If float values are all 0, we just want the quantized values to be 0 as well.
     # So overriding the saturationvalue to '2n', so the scale becomes 1
     diff = 2 * sat_val
-    diff[ diff == 0.0 ] = n_levels
+    diff[diff == 0.0] = n_levels
     scale = diff / n_levels
     zero_point = torch.zeros_like(scale)
     return n_levels, scale, zero_point
-
+
+
 def per_channel_axis(
     scale: torch.FloatTensor,
     zero_point: torch.IntTensor,
@@ -349,7 +351,7 @@ def per_channel_axis(
         tensor_shape (torch.Size): Shape of quantized tensor
 
     Returns:
-        scale, zero_point: 
+        scale, zero_point:
     """
     if axis == 0:
         scale = scale.unsqueeze(1)
@@ -359,9 +361,9 @@ def per_channel_axis(
         zero_point = zero_point.unsqueeze(0)
     else:
         raise ValueError("Axis must be 0 or 1")
-
+
     # Check that tensor shape axis is same as scale/zp broadcast
     assert tensor_shape[axis] == scale.shape[axis]
     assert tensor_shape[axis] == zero_point.shape[axis]
-
+
     return scale, zero_point
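
Since these hunks only touch spacing, a worked sketch of the surrounding parameter math may help; it assumes n_levels = 2**num_bits - 1 (the n_levels computation sits outside this diff) and uses illustrative clip values:

import torch

num_bits = 8
n_levels = 2**num_bits - 1           # assumed; computed outside this diff
sat_min = torch.tensor([-0.5, 0.0])  # lower clip values; 2nd channel degenerate
sat_max = torch.tensor([1.5, 0.0])   # upper clip values

diff = sat_max - sat_min
# All-zero range: force diff to n_levels so the scale becomes 1 and the
# quantized values stay 0, exactly what diff[diff == 0.0] = n_levels does.
diff[diff == 0.0] = n_levels
scale = diff / n_levels              # [2/255, 1.0]
zero_point = -sat_min / scale        # [63.75, 0.0]

# per_channel_axis-style broadcast along axis 0: unsqueeze so scale and
# zero_point align with a [channels, features] tensor.
tensor_shape = torch.Size([2, 3])
scale, zero_point = scale.unsqueeze(1), zero_point.unsqueeze(1)
assert tensor_shape[0] == scale.shape[0] and tensor_shape[0] == zero_point.shape[0]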

fms_mo/quant_refactor/lsq_new.py

8 additions & 6 deletions

@@ -29,14 +29,14 @@
 import torch
 
 # Local
-from fms_mo.quant_refactor.base_quant import Quantizer, Qscheme
-from fms_mo.quant_refactor.per_tensor_ste import PerTensorSTE
+from fms_mo.quant_refactor.base_quant import Qscheme, Quantizer
 from fms_mo.quant_refactor.linear_utils import (
     asymmetric_linear_quantization_params,
     linear_dequantize,
     linear_quantize_LSQresidual,
     qint_bounds,
 )
+from fms_mo.quant_refactor.per_tensor_ste import PerTensorSTE
 
 clip_valn_default = torch.tensor(-8.0)
 clip_val_default = torch.tensor(8.0)
@@ -49,6 +49,7 @@
     qlevel_lowering=False,
 )
 
+
 class LSQQuantization_new(Quantizer):
     """
     LSQ Quantizer
@@ -64,7 +65,7 @@ def __init__(
         init_clip_val: torch.FloatTensor = clip_val_default,
         qscheme=qscheme_per_tensor,
         dequantize: bool = True,
-        **kwargs
+        **kwargs,
     ):
         """
         Init LSQ Quantizer
@@ -135,8 +136,9 @@ def forward(
             torch.Tensor: Dequantized or Quantized output tensor.
         """
 
-        clip_valn, clip_val = clip_valn.to(input_tensor.dtype), clip_val.to(
-            input_tensor.dtype
+        clip_valn, clip_val = (
+            clip_valn.to(input_tensor.dtype),
+            clip_val.to(input_tensor.dtype),
         )
 
         n_levels, scale, zero_point = asymmetric_linear_quantization_params(
@@ -205,7 +207,7 @@ def __init__(
         init_clip_val: torch.FloatTensor = clip_val_default,
         qscheme=qscheme_per_tensor,
         dequantize: bool = True,
-        **kwargs
+        **kwargs,
     ):
         """
         Init LSQ+ Quantizer
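
The one functional-looking hunk here is still pure formatting: the tuple assignment casting both clip bounds to the activation dtype is only re-wrapped. A tiny standalone sketch with illustrative tensors:

import torch

input_tensor = torch.randn(4, dtype=torch.float16)
clip_valn, clip_val = torch.tensor(-8.0), torch.tensor(8.0)

# Cast both learnable clip bounds to the activation dtype in one tuple
# assignment, matching the reformatted block above.
clip_valn, clip_val = (
    clip_valn.to(input_tensor.dtype),
    clip_val.to(input_tensor.dtype),
)
print(clip_val.dtype)  # torch.float16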

fms_mo/quant_refactor/pact2_new.py

4 additions & 6 deletions

@@ -20,11 +20,8 @@
 import torch
 
 # Local
-from fms_mo.quant_refactor.base_quant import Quantizer, Qscheme
-from fms_mo.quant_refactor.per_tensor_ste import (
-    PerTensorSTE,
-    PerTensorSTE_PTnative,
-)
+from fms_mo.quant_refactor.base_quant import Qscheme, Quantizer
+from fms_mo.quant_refactor.per_tensor_ste import PerTensorSTE, PerTensorSTE_PTnative
 
 clip_valn_default = torch.tensor(-8.0)
 clip_val_default = torch.tensor(8.0)
@@ -37,6 +34,7 @@
     qlevel_lowering=False,
 )
 
+
 class PACT2_new(Quantizer):
     """
     Two-sided original PACT
@@ -52,7 +50,7 @@ def __init__(
         qscheme: Qscheme = qscheme_per_tensor,
         dequantize: bool = True,
         pact_plus: bool = True,
-        **kwargs
+        **kwargs,
     ):
         """
         Init PACT2 quantizer

fms_mo/quant_refactor/pact2sym_new.py

4 additions & 6 deletions

@@ -20,11 +20,8 @@
 import torch
 
 # Local
-from fms_mo.quant_refactor.base_quant import Quantizer, Qscheme
-from fms_mo.quant_refactor.per_tensor_ste import (
-    PerTensorSTE,
-    PerTensorSTE_PTnative,
-)
+from fms_mo.quant_refactor.base_quant import Qscheme, Quantizer
+from fms_mo.quant_refactor.per_tensor_ste import PerTensorSTE, PerTensorSTE_PTnative
 
 clip_valn_default = torch.tensor(-8.0)
 clip_val_default = torch.tensor(8.0)
@@ -37,6 +34,7 @@
     qlevel_lowering=False,
 )
 
+
 class PACT2Sym_new(Quantizer):
     """
     Two-sided PACT with symmetric clip values
@@ -52,7 +50,7 @@ def __init__(
         init_clip_val: torch.FloatTensor = clip_val_default,
         qscheme: Qscheme = qscheme_per_tensor,
         dequantize: bool = True,
-        **kwargs
+        **kwargs,
     ):
         """
         Init PACT2Sym quantizer

fms_mo/quant_refactor/pact_new.py

4 additions & 6 deletions

@@ -20,11 +20,8 @@
 import torch
 
 # Local
-from fms_mo.quant_refactor.base_quant import Quantizer, Qscheme
-from fms_mo.quant_refactor.per_tensor_ste import (
-    PerTensorSTE,
-    PerTensorSTE_PTnative,
-)
+from fms_mo.quant_refactor.base_quant import Qscheme, Quantizer
+from fms_mo.quant_refactor.per_tensor_ste import PerTensorSTE, PerTensorSTE_PTnative
 
 clip_valn_default = torch.tensor(0.0)
 clip_val_default = torch.tensor(8.0)
@@ -37,6 +34,7 @@
     qlevel_lowering=False,
 )
 
+
 class PACT_new(Quantizer):
     """
     1-sided original PACT
@@ -54,7 +52,7 @@ def __init__(
         qscheme: Qscheme = qscheme_per_tensor,
         dequantize: bool = True,
         pact_plus: bool = True,
-        **kwargs
+        **kwargs,
     ):
         """
         Initialize PACT quantizer

fms_mo/quant_refactor/pactplussym_new.py

4 additions & 6 deletions

@@ -20,11 +20,8 @@
 import torch
 
 # Local
-from fms_mo.quant_refactor.base_quant import Quantizer, Qscheme
-from fms_mo.quant_refactor.per_tensor_ste import (
-    PerTensorSTE,
-    PerTensorSTE_PTnative,
-)
+from fms_mo.quant_refactor.base_quant import Qscheme, Quantizer
+from fms_mo.quant_refactor.per_tensor_ste import PerTensorSTE, PerTensorSTE_PTnative
 
 clip_valn_default = torch.tensor(-8.0)
 clip_val_default = torch.tensor(8.0)
@@ -37,6 +34,7 @@
     qlevel_lowering=False,
 )
 
+
 class PACTplusSym_new(Quantizer):
     """
     Two-sided symmetric PACT+
@@ -54,7 +52,7 @@ def __init__(
         qscheme: Qscheme = qscheme_per_tensor,
         dequantize: bool = True,
         extend_act_range: bool = False,
-        **kwargs
+        **kwargs,
     ):
         """
         Init PACT+Sym quantizer
