
Commit 3e8dbce

fix: torch_quantizer perCh changes
Signed-off-by: Brandon Groth <[email protected]>
1 parent b77d0bf commit 3e8dbce

File tree

1 file changed: +54 -19 lines changed


fms_mo/quant_refactor/torch_quantizer.py

Lines changed: 54 additions & 19 deletions
@@ -73,8 +73,14 @@ def __init__(
             num_bits.item() if isinstance(num_bits, torch.Tensor) else num_bits
         )
         # turn clips into tensors (from python float)
-        self.clip_low = torch.Tensor([clip_low])
-        self.clip_high = torch.Tensor([clip_high])
+        self.clip_low = (
+            torch.Tensor([clip_low]) if not isinstance(clip_low, torch.Tensor)
+            else clip_low
+        )
+        self.clip_high = (
+            torch.Tensor([clip_high]) if not isinstance(clip_high, torch.Tensor)
+            else clip_high
+        )
         self.symmetric_zp0 = False
         self.qscheme = qscheme
         self.set_quant_bounds()
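
With this change the constructor accepts clip bounds that are already tensors (for example, one clip value per output channel) and only wraps plain Python floats. A minimal sketch, not part of the commit, of how per-channel clip tensors might be built for a weight matrix (variable names are illustrative):

import torch

# One clip pair per output channel (axis 0) of an (out_channels, in_features) weight.
weight = torch.randn(8, 16)
clip_low = weight.amin(dim=1)    # shape (8,)
clip_high = weight.amax(dim=1)   # shape (8,)
# Passing these tensors through unchanged is what the isinstance checks above enable;
# scalar floats still get wrapped via torch.Tensor([...]) as before.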
@@ -102,11 +108,11 @@ def get_setup(self):
         """
         return (
             self.num_bits_int,
-            self.clip_low.item(),
-            self.clip_high.item(),
+            self.clip_low,
+            self.clip_high,
             self.n_levels.item(),
-            self.scale.item(),
-            self.zero_point.item(),
+            self.scale,
+            self.zero_point,
             self.quant_min,
             self.quant_max,
             self.qscheme,
@@ -127,7 +133,7 @@ def set_quant_bounds(self):
         self.scale = (self.clip_high - self.clip_low) / (self.n_levels)
         # this "ZP" will map the float value we choose (clip_low in this case) to the 0 bin
         self.zero_point = (
-            torch.tensor(0)
+            torch.zeros(self.scale.shape, dtype=torch.int)
            if (self.is_symmetric)
            else torch.round(-self.clip_low / self.scale).to(torch.int)
        )
@@ -138,7 +144,7 @@ def set_quant_range(self):
         """
         Set quantization integer range based on member variables
         """
-        if self.is_symmetric and self.zero_point == 0:
+        if self.is_symmetric and torch.sum(self.zero_point) == 0:
             # Either [-8,7];[-128,127] for non-symmetric or [-7,7];[-127,127] for qlevel_lowering
             self.quant_min, self.quant_max = (
                 -(2 ** (self.num_bits - 1)) + self.symmetric_nlevel,
@@ -264,7 +270,7 @@ def get_torch_dtype(self):
         if self.is_single_sided:
             signed = False
         else:
-            signed = (self.zero_point == 0).item()
+            signed = (torch.sum(self.zero_point) == 0).item()
         return self.dtype_dict.get(
             (self.num_bits_int, signed)
         )  # NOTE .item() won't work for perCh
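
Because the quantizer can now hold one zero point per channel, zero_point is initialized as a vector of zeros matching the scale's shape, and comparisons against 0 are reduced with torch.sum so they produce a single condition rather than a boolean tensor. A quick illustration (shapes assumed, not from the commit):

import torch

scale = torch.rand(8)                                   # per-channel scales
zero_point = torch.zeros(scale.shape, dtype=torch.int)  # one zero point per channel
# (zero_point == 0) is a length-8 bool tensor, which cannot be used directly as an
# if-condition; in the symmetric case every entry is 0, so the sum is 0 as well.
print(bool(torch.sum(zero_point) == 0))                 # True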
@@ -282,24 +288,53 @@ def forward(self, tensor: torch.FloatTensor):
         Returns:
             torch.FloatTensor: Quantized or dequantized tensor.
         """
+
         if self.dequantize:
-            output = torch.fake_quantize_per_tensor_affine(
-                tensor, self.scale, self.zero_point, self.quant_min, self.quant_max
-            )
+            if self.qscheme.Nch:  # Per Channel
+                output = torch.fake_quantize_per_channel_affine(
+                    tensor,
+                    self.scale.float(),
+                    self.zero_point.float(),
+                    self.qscheme.axis,
+                    self.quant_min,
+                    self.quant_max,
+                )
+            elif self.qscheme.Ngrp:  # Per Group
+                pass
+            else:  # Per Tensor
+                output = torch.fake_quantize_per_tensor_affine(
+                    tensor,
+                    self.scale,
+                    self.zero_point,
+                    self.quant_min,
+                    self.quant_max,
+                )
         else:
             dtype = self.get_torch_dtype()
             if dtype:
-                # Clamp to [quant_min, quant_max] in case we are storing int4 into int8/uint8 tensor
-                output = (
-                    torch.quantize_per_tensor(
-                        tensor, self.scale, self.zero_point, dtype
+                if self.qscheme.q_unit == "perCh":
+                    output = torch.quantize_per_channel(
+                        tensor,
+                        self.scale,
+                        self.zero_point,
+                        self.qscheme.axis,
+                        dtype,
                     )
-                    .int_repr()
-                    .clamp(self.quant_min, self.quant_max)
-                )
+                elif self.qscheme.q_unit == "perGrp":
+                    raise RuntimeError("TorchQuantizer forward not implemented for perGrp")
+                else:  # Per Tensor
+                    output = torch.quantize_per_tensor(
+                        tensor,
+                        self.scale,
+                        self.zero_point,
+                        dtype,
+                    )
+                # Clamp required if storing int4 into int8 tensor (no PT support for int4)
+                output = output.int_repr().clamp(self.quant_min, self.quant_max)
             else:
                 raise RuntimeError(
                     f"num_bits {self.num_bits} and sign {(self.zero_point==0).item()}"
                     "combination results in unavailable dtype."
                 )
+
         return output
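
For context, and not part of the commit, the new per-channel dequantize branch boils down to a call like the following, shown standalone with an assumed symmetric 4-bit range and channel axis 0:

import torch

weight = torch.randn(8, 16)                  # 8 output channels
clip_high = weight.abs().amax(dim=1)         # per-channel clip, shape (8,)
n_levels = 2 ** 4 - 2                        # symmetric 4-bit grid: [-7, 7]
scale = (2 * clip_high) / n_levels           # per-channel scale
zero_point = torch.zeros_like(scale)         # symmetric => zero point 0 per channel

# Fake-quantize: snap each channel to its integer grid, then map back to float.
fq = torch.fake_quantize_per_channel_affine(
    weight, scale.float(), zero_point.float(), 0, -7, 7
)
print(fq.shape)                              # same shape as weight

On the true-quantize path, torch.quantize_per_channel returns a quantized tensor whose int_repr() is then clamped to [quant_min, quant_max], since PyTorch has no native int4 storage.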

0 commit comments
