Skip to content

Commit fe8cd9d

Browse files
committed
feat: Added per_channel_axis
Signed-off-by: Brandon Groth <[email protected]>
1 parent 3bd9401 commit fe8cd9d

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

fms_mo/quant_refactor/linear_utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,3 +330,38 @@ def symmetric_linear_quantization_params(
330330
scale = diff / n_levels
331331
zero_point = torch.zeros_like(scale)
332332
return n_levels, scale, zero_point
333+
334+
def per_channel_axis(
    scale: torch.FloatTensor,
    zero_point: torch.IntTensor,
    tensor_shape: torch.Size,
    axis: int = 0,
):
    """
    Change scale and zero_point to target axis dimension of input tensor.
    Axis values: 0 -> Nx1, 1 -> 1xN

    Note: for Transformers, axis = 0 is desired

    Args:
        scale (torch.FloatTensor): Dequantized range of a quantized integer bin.
        zero_point (torch.IntTensor): Quantized int bin mapping to fp 0.0.
        tensor_shape (torch.Size): Shape of quantized tensor
        axis (int, optional): Axis of the quantized tensor that the
            per-channel parameters run along. 0 inserts a new dimension at
            position 1 (Nx1); 1 inserts a new dimension at position 0 (1xN).
            Defaults to 0.

    Returns:
        scale, zero_point: The two inputs with one singleton dimension
            inserted via ``unsqueeze`` as selected by ``axis``.

    Raises:
        ValueError: If ``axis`` is not 0 or 1, or if the size of ``scale`` or
            ``zero_point`` along ``axis`` differs from ``tensor_shape[axis]``.
    """
    if axis == 0:
        scale = scale.unsqueeze(1)
        zero_point = zero_point.unsqueeze(1)
    elif axis == 1:
        scale = scale.unsqueeze(0)
        zero_point = zero_point.unsqueeze(0)
    else:
        raise ValueError("Axis must be 0 or 1")

    # Check that tensor shape axis is same as scale/zp broadcast.
    # Explicit raises (not ``assert``) so the check survives ``python -O``.
    expected = tensor_shape[axis]
    for name, param in (("scale", scale), ("zero_point", zero_point)):
        if param.shape[axis] != expected:
            raise ValueError(
                f"{name} has size {param.shape[axis]} along axis {axis}, "
                f"but tensor_shape[{axis}] is {expected}"
            )

    return scale, zero_point

0 commit comments

Comments
 (0)