
Commit e0d1a0a

feat: Added per_channel_axis support to per_channel_ste
Signed-off-by: Brandon Groth <[email protected]>
1 parent 4bc34f3 commit e0d1a0a

1 file changed: +53 -10 lines changed


fms_mo/quant_refactor/per_channel_ste.py

Lines changed: 53 additions & 10 deletions
@@ -25,6 +25,7 @@
     linear_quantization,
     symmetric_linear_quantization_params,
     transform_clips,
+    per_channel_axis,
 )


@@ -50,6 +51,7 @@ def forward(
         dequantize: bool = True,
         symmetric: bool = False,
         qlevel_lowering: bool = False,
+        axis: int = 0,
     ):
         """
         General forward method:
@@ -61,21 +63,23 @@ def forward(

         Args:
             ctx (torch.autograd.Function): Forward/Backward context object.
-            input_tensor_tensor (torch.FloatTensor): Tensor to be quantized.
+            input_tensor (torch.FloatTensor): Tensor to be quantized.
             num_bits (torch.IntTensor): Number of bit for quantization.
             clip_valn (torch.FloatTensor): Lower clip value bound.
             clip_val (torch.FloatTensor): Upper clip value bound.
             dequantize (bool, optional): Return dequantized or int tensor. Defaults to True.
             symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
             qlevel_lowering (bool, optional): Specify lowering of quantized levels.
                 Defaults to True.
+            axis (int, optional): Specify which tensor dimension to quantize indiviually.
+                Defaults to 0.

         Returns:
             torch.Tensor: Dequantized or Quantized output tensor.
         """
         clip_valn, clip_val = transform_clips(input_tensor.dtype, clip_valn, clip_val)
         n_levels, scale, zero_point = PerChannelSTE.calc_qparams(
-            input_tensor, num_bits, clip_valn, clip_val, qlevel_lowering
+            num_bits, clip_valn, clip_val, qlevel_lowering, axis, input_tensor.shape
         )
         PerChannelSTE.save_tensors(
             ctx,
@@ -100,6 +104,8 @@ def calc_qparams(
         clip_val: torch.FloatTensor,
         symmetric: bool = False,
         qlevel_lowering: bool = True,
+        axis: int = 0,
+        tensor_shape: torch.Size = None,
     ):
         """
         Compute the scale and zero_point from num_bits and clip values
@@ -111,6 +117,8 @@ def calc_qparams(
             symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
             qlevel_lowering (bool, optional): Specify lowering of quantized levels.
                 Defaults to True.
+            axis (int, optional): Specify which tensor dimension to quantize indiviually.
+                Defaults to 0.

         Returns:
             torch.IntTensor, torch.FloatTensor, torch.IntTensor: Quantized parameters
@@ -129,9 +137,24 @@ def calc_qparams(
             integral_zero_point=True,
             signed=False,
         )
-        return n_levels, scale, zero_point
+
+        # Broadcast scale, zero_point based on axis
+        scale, zero_point = per_channel_axis(scale, zero_point, tensor_shape, axis)

+        return n_levels, scale, zero_point
+
+    # The save_tensors and backward unpacking must be synced
+    @classmethod
+    def save_tensors(cls, ctx, tensors) -> None:
+        """
+        Save computed data to ctx for backward()

+        Args:
+            ctx (torch.autograd.Function): Forward/Backward context object.
+            tensors (list(torch.Tensor)): List of tensors to save.
+        """
+        ctx.save_for_backward(*tensors)
+
     @staticmethod
     def backward(ctx, grad_output):
         """
@@ -172,6 +195,7 @@ def forward(
         dequantize: bool = True,
         symmetric: bool = False,
         qlevel_lowering: bool = False,
+        axis: int = 0,
     ):
         """
         General forward method:
@@ -190,6 +214,8 @@ def forward(
             symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
             qlevel_lowering (bool, optional): Specify lowering of quantized levels.
                 Defaults to True.
+            axis (int, optional): Specify which tensor dimension to quantize indiviually.
+                Defaults to 0.

         Returns:
             torch.Tensor: Dequantized or Quantized output tensor.
@@ -207,10 +233,10 @@ def forward(
             qint_h,
             qint_dtype,
         ) = PerChannelSTE_PTnative.calc_qparams(
-            num_bits, clip_valn, clip_val, symmetric, qlevel_lowering
+            num_bits, clip_valn, clip_val, symmetric, qlevel_lowering,
         )
         output = PerChannelSTE_PTnative.linear_quantization(
-            input_tensor, scale, zero_point, qint_l, qint_h, qint_dtype, dequantize
+            input_tensor, scale, zero_point, qint_l, qint_h, qint_dtype, dequantize, axis
         )
         return output

@@ -248,6 +274,7 @@ def calc_qparams(
         qint_l, qint_h, qint_dtype = PerChannelSTE_PTnative.qint_bounds(
             num_bits, zero_point, symmetric, qlevel_lowering
         )
+        # Note: fake_quantize_per_channel_affine does not need matching dimensions for scale/zp to tensor
         return n_levels, scale, zero_point, qint_l, qint_h, qint_dtype

     @classmethod
@@ -288,6 +315,7 @@ def linear_quantization(
         qint_h: int,
         qint_dtype: torch.dtype,
         dequantize: bool = True,
+        axis: int = 0,
     ) -> torch.Tensor:
         """
         Linear quantization for PTnative STE
@@ -300,6 +328,8 @@ def linear_quantization(
             qint_h (int): Quantized integer upper clip value.
             qint_dtype (torch.dtype): Quantized integer dtype.
             dequantize (bool, optional): Specify to return fp or quantized int. Defaults to True.
+            axis (int, optional): Specify which tensor dimension to quantize indiviually.
+                Defaults to 0.

         Returns:
             torch.Tensor: PTnative quantized or dequantized tensor.
@@ -310,7 +340,7 @@ def linear_quantization(
                 input_tensor.float(),
                 scale.float(),
                 zero_point,
-                axis=0,
+                axis=axis,
                 quant_min=qint_l,
                 quant_max=qint_h,
             ).to(input_tensor.dtype)
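
Forwarding axis here means the dequantize path can now fake-quantize along any dimension instead of the hard-coded dimension 0. torch.fake_quantize_per_channel_affine expects 1-D scale and zero_point tensors whose length equals input.size(axis), which is why the PTnative calc_qparams above skips the broadcast step. A standalone usage sketch with illustrative values:

import torch

x = torch.randn(8, 4, 3, 3)                      # e.g. a conv weight, output channels first
axis = 0                                         # one (scale, zero_point) pair per slice along axis
scale = torch.full((x.size(axis),), 0.1)         # 1-D, length must equal x.size(axis)
zero_point = torch.zeros(x.size(axis), dtype=torch.int32)

y = torch.fake_quantize_per_channel_affine(
    x, scale, zero_point, axis=axis, quant_min=0, quant_max=255
)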
@@ -321,7 +351,7 @@ def linear_quantization(
                 input_tensor.float(),
                 scale.float(),
                 zero_point,
-                axis=0,
+                axis=axis,
                 dtype=qint_dtype
             )
             .int_repr()
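When dequantize is False, the same axis is forwarded to torch.quantize_per_channel and the raw integer values are recovered with int_repr(), mirroring the code above. Illustrative values again:

import torch

x = torch.randn(8, 4)
axis = 0
scale = torch.full((x.size(axis),), 0.1)
zero_point = torch.zeros(x.size(axis), dtype=torch.int64)

q = torch.quantize_per_channel(x, scale, zero_point, axis=axis, dtype=torch.quint8)
q_int = q.int_repr()                             # plain torch.uint8 tensor of quantized values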
@@ -362,6 +392,7 @@ def forward(
         symmetric: bool = False,
         qlevel_lowering: bool = False,
         use_code: bool = False,
+        axis: int = 0,
     ):
         """
         General forward method:
@@ -387,7 +418,14 @@ def forward(
         """
         clip_valn, clip_val = transform_clips(input_tensor.dtype, clip_valn, clip_val)
         n_levels, scale, zero_point = PerChannelSTESAWB.calc_qparams(
-            input_tensor, num_bits, clip_valn, clip_val, qlevel_lowering, use_code
+            num_bits,
+            clip_valn,
+            clip_val,
+            symmetric,
+            qlevel_lowering,
+            axis,
+            input_tensor.shape,
+            use_code,
         )
         PerChannelSTE.save_tensors(
             ctx,
@@ -408,10 +446,12 @@ def forward(
408446
def calc_qparams(
409447
cls,
410448
num_bits: torch.IntTensor,
411-
clip_valn: torch.FloatTensor,
449+
_clip_valn: torch.FloatTensor,
412450
clip_val: torch.FloatTensor,
413451
symmetric: bool = False,
414452
qlevel_lowering: bool = True,
453+
axis: int = 0,
454+
tensor_shape: torch.Size = None,
415455
use_code: bool = False,
416456
):
417457
"""
@@ -440,7 +480,10 @@ def calc_qparams(
                 num_bits, clip_val, qlevel_lowering
             )

+            # Broadcast scale, zero_point to tensor shape
+            scale, zero_point = per_channel_axis(scale, zero_point, tensor_shape, axis)
+
             output = n_levels, scale, zero_point
         else:
             raise ValueError("SAWB has non-symmetric Qscheme")
-        return output
+        return output
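
Taken together, the change lets callers derive quantization parameters per channel along an arbitrary dimension rather than only dimension 0. A self-contained sketch of what axis-aware symmetric parameters look like in plain PyTorch (illustrative clip values, not how fms_mo computes them):

import torch

x = torch.randn(4, 16)                           # e.g. activations with channels along dim 1
axis = 1
num_bits = 8

clip_val = x.abs().amax(dim=0)                   # per-channel symmetric clip, shape (16,)
scale = clip_val / (2 ** (num_bits - 1) - 1)     # symmetric scale with lowered qlevels
zero_point = torch.zeros_like(clip_val, dtype=torch.int32)

x_q = torch.fake_quantize_per_channel_affine(
    x, scale, zero_point, axis=axis, quant_min=-127, quant_max=127
)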
