     linear_quantization,
     symmetric_linear_quantization_params,
     transform_clips,
+    per_channel_axis,
 )
 
 
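For orientation, here is a minimal sketch of what the new `per_channel_axis` helper is assumed to do, based on its call sites in this diff: reshape the 1-D per-channel `scale` and `zero_point` vectors so they broadcast against the input tensor along `axis`. The body below is illustrative, not the implementation from this PR.

import torch

def per_channel_axis(scale, zero_point, tensor_shape, axis=0):
    # Keep the size along `axis`, insert singleton dims elsewhere,
    # e.g. for tensor_shape=(64, 128) and axis=0: scale (64,) -> (64, 1).
    view_shape = [1] * len(tensor_shape)
    view_shape[axis] = tensor_shape[axis]
    return scale.reshape(view_shape), zero_point.reshape(view_shape)
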
@@ -50,6 +51,7 @@ def forward(
         dequantize: bool = True,
         symmetric: bool = False,
         qlevel_lowering: bool = False,
+        axis: int = 0,
     ):
         """
         General forward method:
@@ -61,21 +63,23 @@ def forward(
 
         Args:
             ctx (torch.autograd.Function): Forward/Backward context object.
-            input_tensor_tensor (torch.FloatTensor): Tensor to be quantized.
+            input_tensor (torch.FloatTensor): Tensor to be quantized.
             num_bits (torch.IntTensor): Number of bits for quantization.
             clip_valn (torch.FloatTensor): Lower clip value bound.
             clip_val (torch.FloatTensor): Upper clip value bound.
             dequantize (bool, optional): Return dequantized or int tensor. Defaults to True.
             symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
             qlevel_lowering (bool, optional): Specify lowering of quantized levels.
                 Defaults to True.
+            axis (int, optional): Specify which tensor dimension to quantize individually.
+                Defaults to 0.
 
         Returns:
             torch.Tensor: Dequantized or quantized output tensor.
         """
         clip_valn, clip_val = transform_clips(input_tensor.dtype, clip_valn, clip_val)
         n_levels, scale, zero_point = PerChannelSTE.calc_qparams(
-            input_tensor, num_bits, clip_valn, clip_val, qlevel_lowering
+            num_bits, clip_valn, clip_val, symmetric, qlevel_lowering, axis, input_tensor.shape
         )
         PerChannelSTE.save_tensors(
             ctx,
@@ -100,6 +104,8 @@ def calc_qparams(
         clip_val: torch.FloatTensor,
         symmetric: bool = False,
         qlevel_lowering: bool = True,
+        axis: int = 0,
+        tensor_shape: torch.Size = None,
     ):
         """
         Compute the scale and zero_point from num_bits and clip values
@@ -111,6 +117,10 @@ def calc_qparams(
             symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
             qlevel_lowering (bool, optional): Specify lowering of quantized levels.
                 Defaults to True.
+            axis (int, optional): Specify which tensor dimension to quantize individually.
+                Defaults to 0.
+            tensor_shape (torch.Size, optional): Shape of the tensor to be quantized.
+                Defaults to None.
 
         Returns:
             torch.IntTensor, torch.FloatTensor, torch.IntTensor: Quantized parameters
@@ -129,9 +137,24 @@ def calc_qparams(
             integral_zero_point=True,
             signed=False,
         )
-        return n_levels, scale, zero_point
+
+        # Broadcast scale, zero_point based on axis
+        scale, zero_point = per_channel_axis(scale, zero_point, tensor_shape, axis)
 
+        return n_levels, scale, zero_point
+
+    # The save_tensors and backward unpacking must be kept in sync
+    @classmethod
+    def save_tensors(cls, ctx, tensors) -> None:
+        """
+        Save computed data to ctx for backward()
 
+        Args:
+            ctx (torch.autograd.Function): Forward/Backward context object.
+            tensors (list(torch.Tensor)): List of tensors to save.
+        """
+        ctx.save_for_backward(*tensors)
+
     @staticmethod
     def backward(ctx, grad_output):
         """
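The comment introduced above flags an invariant worth spelling out: the order of tensors handed to `save_tensors` must match the unpacking order in `backward`. A small illustration follows; the saved-tensor list here is made up, since the actual list is truncated in this hunk.

# forward() packs (hypothetical tensor list):
PerChannelSTE.save_tensors(ctx, [input_tensor, clip_valn, clip_val, scale, zero_point])

# backward() must then unpack in exactly the same order:
input_tensor, clip_valn, clip_val, scale, zero_point = ctx.saved_tensors
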
@@ -172,6 +195,7 @@ def forward(
         dequantize: bool = True,
         symmetric: bool = False,
         qlevel_lowering: bool = False,
+        axis: int = 0,
     ):
         """
         General forward method:
@@ -190,6 +214,8 @@ def forward(
             symmetric (bool, optional): Specify if clip values are symmetric. Defaults to False.
             qlevel_lowering (bool, optional): Specify lowering of quantized levels.
                 Defaults to True.
+            axis (int, optional): Specify which tensor dimension to quantize individually.
+                Defaults to 0.
 
         Returns:
             torch.Tensor: Dequantized or quantized output tensor.
@@ -207,10 +233,10 @@ def forward(
             qint_h,
             qint_dtype,
         ) = PerChannelSTE_PTnative.calc_qparams(
-            num_bits, clip_valn, clip_val, symmetric, qlevel_lowering
+            num_bits, clip_valn, clip_val, symmetric, qlevel_lowering,
         )
         output = PerChannelSTE_PTnative.linear_quantization(
-            input_tensor, scale, zero_point, qint_l, qint_h, qint_dtype, dequantize
+            input_tensor, scale, zero_point, qint_l, qint_h, qint_dtype, dequantize, axis
         )
         return output
 
@@ -248,6 +274,7 @@ def calc_qparams(
         qint_l, qint_h, qint_dtype = PerChannelSTE_PTnative.qint_bounds(
             num_bits, zero_point, symmetric, qlevel_lowering
         )
+        # Note: fake_quantize_per_channel_affine does not need scale/zero_point dims to match the tensor
         return n_levels, scale, zero_point, qint_l, qint_h, qint_dtype
 
     @classmethod
@@ -288,6 +315,7 @@ def linear_quantization(
         qint_h: int,
         qint_dtype: torch.dtype,
         dequantize: bool = True,
+        axis: int = 0,
     ) -> torch.Tensor:
         """
         Linear quantization for PTnative STE
@@ -300,6 +328,8 @@ def linear_quantization(
             qint_h (int): Quantized integer upper clip value.
             qint_dtype (torch.dtype): Quantized integer dtype.
             dequantize (bool, optional): Specify to return fp or quantized int. Defaults to True.
+            axis (int, optional): Specify which tensor dimension to quantize individually.
+                Defaults to 0.
 
         Returns:
             torch.Tensor: PTnative quantized or dequantized tensor.
@@ -310,7 +340,7 @@ def linear_quantization(
             input_tensor.float(),
             scale.float(),
             zero_point,
-            axis=0,
+            axis=axis,
             quant_min=qint_l,
             quant_max=qint_h,
         ).to(input_tensor.dtype)
@@ -321,7 +351,7 @@ def linear_quantization(
                 input_tensor.float(),
                 scale.float(),
                 zero_point,
-                axis=0,
+                axis=axis,
                 dtype=qint_dtype
             )
             .int_repr()
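Both PyTorch ops touched here accept 1-D `scale`/`zero_point` vectors of length `input.size(axis)`, which is why the note above says this path needs no broadcasting to the full tensor shape. A standalone example of the two calls with the newly threaded `axis` (shapes and values illustrative):

import torch

x = torch.randn(4, 8)                           # 4 channels along axis 0
scale = torch.full((4,), 0.1)                   # one scale per channel
zero_point = torch.zeros(4, dtype=torch.int32)  # one zero point per channel

# dequantize=True path: fake-quantize, float in / float out
y_fp = torch.fake_quantize_per_channel_affine(
    x, scale, zero_point, axis=0, quant_min=-128, quant_max=127
)

# dequantize=False path: real quantization, then the raw integer values
y_int = torch.quantize_per_channel(
    x, scale, zero_point, axis=0, dtype=torch.qint8
).int_repr()
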
@@ -362,6 +392,7 @@ def forward(
         symmetric: bool = False,
         qlevel_lowering: bool = False,
         use_code: bool = False,
+        axis: int = 0,
     ):
         """
         General forward method:
@@ -387,7 +418,14 @@ def forward(
         """
         clip_valn, clip_val = transform_clips(input_tensor.dtype, clip_valn, clip_val)
         n_levels, scale, zero_point = PerChannelSTESAWB.calc_qparams(
-            input_tensor, num_bits, clip_valn, clip_val, qlevel_lowering, use_code
+            num_bits,
+            clip_valn,
+            clip_val,
+            symmetric,
+            qlevel_lowering,
+            axis,
+            input_tensor.shape,
+            use_code,
         )
         PerChannelSTE.save_tensors(
             ctx,
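The expanded call above now passes eight positional arguments into `calc_qparams`. Given how sensitive these call sites are to positional drift, a keyword form may be worth considering; this sketch uses the same names as the diff (the first clip parameter was renamed to `_clip_valn` below, so the clip bounds stay positional):

n_levels, scale, zero_point = PerChannelSTESAWB.calc_qparams(
    num_bits,
    clip_valn,
    clip_val,
    symmetric=symmetric,
    qlevel_lowering=qlevel_lowering,
    axis=axis,
    tensor_shape=input_tensor.shape,
    use_code=use_code,
)
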
@@ -408,10 +446,12 @@ def forward(
     def calc_qparams(
         cls,
         num_bits: torch.IntTensor,
-        clip_valn: torch.FloatTensor,
+        _clip_valn: torch.FloatTensor,
         clip_val: torch.FloatTensor,
         symmetric: bool = False,
         qlevel_lowering: bool = True,
+        axis: int = 0,
+        tensor_shape: torch.Size = None,
         use_code: bool = False,
     ):
         """
@@ -440,7 +480,10 @@ def calc_qparams(
                 num_bits, clip_val, qlevel_lowering
             )
 
+            # Broadcast scale, zero_point to the tensor shape
+            scale, zero_point = per_channel_axis(scale, zero_point, tensor_shape, axis)
+
             output = n_levels, scale, zero_point
         else:
             raise ValueError("SAWB has non-symmetric Qscheme")
-        return output
+        return output
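Putting the pieces together, a hedged end-to-end sketch of the new `axis` plumbing through the base STE; the argument order follows the `forward` signature in this diff, while the shapes and clip values are illustrative.

import torch

w = torch.randn(64, 128)          # 64 output channels on axis 0
clip_val = w.abs().amax(dim=1)    # per-channel clip bound, shape (64,)
clip_valn = -clip_val

w_q = PerChannelSTE.apply(
    w,                  # input_tensor
    torch.tensor(8),    # num_bits
    clip_valn,
    clip_val,
    True,               # dequantize
    True,               # symmetric
    False,              # qlevel_lowering
    0,                  # axis (new in this change)
)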