File tree Expand file tree Collapse file tree 1 file changed +35
-0
lines changed
Expand file tree Collapse file tree 1 file changed +35
-0
lines changed Original file line number Diff line number Diff line change @@ -330,3 +330,38 @@ def symmetric_linear_quantization_params(
330330 scale = diff / n_levels
331331 zero_point = torch .zeros_like (scale )
332332 return n_levels , scale , zero_point
333+
def per_channel_axis(
    scale: torch.FloatTensor,
    zero_point: torch.IntTensor,
    tensor_shape: torch.Size,
    axis: int = 0,
):
    """
    Reshape scale and zero_point so they broadcast along the target axis.

    A singleton dimension is inserted opposite to ``axis``. For 1-D inputs of
    length N this gives: axis 0 -> Nx1, axis 1 -> 1xN.

    Note: for Transformers, axis = 0 is desired

    Args:
        scale (torch.FloatTensor): Dequantized range of a quantized integer bin.
        zero_point (torch.IntTensor): Quantized int bin mapping to fp 0.0.
        tensor_shape (torch.Size): Shape of quantized tensor
        axis (int): Per-channel axis of the quantized tensor; must be 0 or 1.
            Defaults to 0.

    Returns:
        scale, zero_point: the two inputs, each with one singleton dimension
        added (at dim 1 when axis == 0, at dim 0 when axis == 1).

    Raises:
        ValueError: If ``axis`` is not 0 or 1, or if the size of
            ``tensor_shape`` along ``axis`` does not match the reshaped
            scale / zero_point along that axis.
    """
    if axis == 0:
        singleton_dim = 1
    elif axis == 1:
        singleton_dim = 0
    else:
        raise ValueError("Axis must be 0 or 1")

    scale = scale.unsqueeze(singleton_dim)
    zero_point = zero_point.unsqueeze(singleton_dim)

    # Check that tensor shape axis is same as scale/zp broadcast.
    # Raised explicitly (not via `assert`) so the check survives `python -O`.
    expected = tensor_shape[axis]
    for label, param in (("scale", scale), ("zero_point", zero_point)):
        if param.shape[axis] != expected:
            raise ValueError(
                f"{label} has size {param.shape[axis]} along axis {axis}, "
                f"but the quantized tensor has size {expected}"
            )

    return scale, zero_point
You can’t perform that action at this time.
0 commit comments