
Commit 0173209

fix: QminmaxPerCh_PTnative patch for unstable rounding in torch.quantize_per_channel

Signed-off-by: Brandon Groth <[email protected]>
1 parent fbf94e0

File tree: 4 files changed (+16 −9 lines)

fms_mo/quant_refactor/linear_utils.py

Lines changed: 1 addition & 5 deletions

@@ -293,14 +293,10 @@ def asymmetric_linear_quantization_params(
     scale = diff / n_levels
     zero_point = -sat_min / scale
     if integral_zero_point:
-        zero_point = zero_point.round()
+        zero_point = zero_point.round().to(torch.int)
     if signed:
         zero_point += 2 ** (num_bits - 1)

-    # Ensure zp in [0, n_levels]
-    zp_bounds = torch.all((zero_point > 0) & (zero_point < n_levels))
-    assert zp_bounds, "Asymmetric zero points should be in [0, 2**bits-1]"
-
     return n_levels, scale, zero_point
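For context, this function computes affine quantization parameters from clip bounds; the patch casts the rounded zero point to an integer dtype and drops the range assertion. A minimal standalone sketch of the same arithmetic (the tensor values and the sat_min/sat_max bounds here are illustrative, not taken from the repo's call sites):

import torch

# Illustrative per-channel clip bounds; sat_min/sat_max stand in for the
# values the real function receives.
sat_min = torch.tensor([-1.0, -0.5])
sat_max = torch.tensor([1.0, 2.0])
num_bits, signed = 8, False

n_levels = 2**num_bits - 1
scale = (sat_max - sat_min) / n_levels
# The patched line: round, then cast to int (the old range assert is gone)
zero_point = (-sat_min / scale).round().to(torch.int)
if signed:
    zero_point += 2 ** (num_bits - 1)
print(scale)       # tensor([0.0078, 0.0098])
print(zero_point)  # tensor([128,  51], dtype=torch.int32)

The integer cast matters downstream: torch.quantize_per_channel requires an integer zero-point tensor, and a float zero point exactly on a .5 boundary (127.5 above) is where the rounding instability this commit addresses can surface.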

fms_mo/quant_refactor/per_channel_ste.py

Lines changed: 0 additions & 1 deletion

@@ -361,7 +361,6 @@ def linear_quantization(
            quant_max=qint_h,
        ).to(input_tensor.dtype)
    else:
-        # Note: scale is multi-valued, but zero_point isn't...
        output = (
            torch.quantize_per_channel(
                input_tensor.float(),
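For reference, a minimal standalone call to torch.quantize_per_channel (the values are made up): scales is a float tensor with one entry per channel along axis, and zero_points must be an integer tensor, which is the same requirement the linear_utils.py change satisfies with .to(torch.int).

import torch

x = torch.tensor([[0.10, 0.63], [-0.42, 1.20]])
scales = torch.tensor([0.01, 0.02])   # "multi-valued": one scale per channel
zero_points = torch.tensor([0, 64])   # integer tensor, one entry per channel
qx = torch.quantize_per_channel(x, scales, zero_points, axis=0, dtype=torch.quint8)
print(qx.int_repr())    # the raw uint8 codes
print(qx.dequantize())  # back to float: (code - zp) * scale per channel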

tests/quantizers/test_qmax.py

Lines changed: 5 additions & 0 deletions

@@ -713,11 +713,16 @@ def test_qmaxnew_asymmetric_perCh(

     setup = torch_quantizer_asymmetric_perCh.get_setup()

+    # QminmaxPerChSTE_PTnative has a rare numerical problem: when
+    # input/scale + zp == (K + 0.5), rounding becomes unstable inside torch.quantize_per_channel
+    error_override = base_options["nativePT"] and other_options_perCh["minmax"]
+
     quantizer_error(
         tensor,
         qtensor_fms_mo,
         qtensor_torch,
         setup,
         base_options,
         other_options_perCh,
+        error_override=error_override,
     )
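The tie-breaking behavior the new comment refers to is round-half-to-even ("banker's rounding"), which torch.round uses. A tiny illustration with values chosen so the scaled input lands exactly on a .5 boundary (this does not reproduce the exact failure path inside torch.quantize_per_channel, only the boundary behavior):

import torch

# 0.75, 1.25, 1.75 and 0.5 are all exact binary floats, so t / scale lands
# exactly on .5 boundaries with no representation fuzz.
t = torch.tensor([0.75, 1.25, 1.75])
scale = 0.5
print(t / scale)               # tensor([1.5000, 2.5000, 3.5000])
print(torch.round(t / scale))  # tensor([2., 2., 4.]) -- ties go to the even integer

Two code paths that reach the same tie through slightly different floating-point operations can land on either side of the boundary, so the FMS path and the native-PyTorch path may disagree by one quantization level. Rather than asserting exact equality, the test loosens the zero tolerances via error_override for this quantizer combination.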

tests/quantizers/test_quantizer_utils.py

Lines changed: 10 additions & 3 deletions

@@ -79,6 +79,7 @@ def quantizer_error(
     max_norm_tol=1e-5,
     l2_norm_tol=1e-2,
     nonzero_tol=1e-2,
+    error_override=False,
 ):
     """
     Check various types of quantizer numerical errors for FMS and Torch quantized tensors
@@ -102,7 +103,7 @@ def quantizer_error(
     """

     # If using PyTorch functions, set error tolerances to zero
-    if base_options["nativePT"]:
+    if base_options["nativePT"] and not error_override:
         max_norm_tol = 0.0
         l2_norm_tol = 0.0
         nonzero_tol = 0.0
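The effect of the flag on tolerance selection, as a self-contained sketch of just this branch (defaults copied from the signature in the first hunk; the function name is hypothetical):

def pick_tolerances(native_pt: bool, error_override: bool = False):
    max_norm_tol, l2_norm_tol, nonzero_tol = 1e-5, 1e-2, 1e-2
    # Patched condition: demand exact equality only when no override is set
    if native_pt and not error_override:
        max_norm_tol = l2_norm_tol = nonzero_tol = 0.0
    return max_norm_tol, l2_norm_tol, nonzero_tol

print(pick_tolerances(True))        # (0.0, 0.0, 0.0) -- exact match required
print(pick_tolerances(True, True))  # (1e-05, 0.01, 0.01) -- override keeps slack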
@@ -207,16 +208,16 @@ def quantizer_error(
        num_bits,
        clip_low,
        clip_high,
+        _n_level,
        scale,
        _zero_point,
-        _n_level,
        _quant_min,
        _quant_max,
        _qscheme,
    ) = setup

    # Check if qtensors are constant for non-constant tensor with appropriate spacing of elements
-    if tensor.unique().numel() > 1 and (tensor.max() - tensor.min()) > scale:
+    if tensor.unique().numel() > 1 and (tensor.max() - tensor.min()) > scale.min():
        fms_mo_unique_vals = qtensor_fms_mo.unique()
        torch_unique_vals = qtensor_torch.unique()
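The scale.min() change matters because setup comes from a per-channel quantizer here, so scale is a tensor rather than a scalar. An elementwise tensor comparison cannot be used directly as an if condition (PyTorch raises "Boolean value of Tensor with more than one element is ambiguous"), while reducing with .min() yields a single truth value and keeps the guard meaningful: the tensor must span more than the finest quantization step. With illustrative values:

import torch

scale = torch.tensor([0.10, 0.20, 0.40])  # per-channel scales (made-up)
spread = torch.tensor(1.0)                # stands in for tensor.max() - tensor.min()

print(spread > scale)        # tensor([True, True, True]) -- elementwise, ambiguous in `if`
print(spread > scale.min())  # tensor(True) -- single value, safe to branch on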

@@ -271,6 +272,8 @@ def quantizer_error(

     assert total_nonscale_nonzero_indices == total_nonzero_indices - total_scale_indices

+    # At this point, we don't want to count any potential problems from banker's rounding
+
     with torch.no_grad():
         try:
             # Check for large difference in values for current dtype (ie underflow/overflow)
@@ -322,6 +325,10 @@ def quantizer_error(
         )
         logger.error("Total Diff vals =%s", diff.unique().numel())
         logger.error("Diff unique vals =\n%s", diff.unique().detach())
+
+        logger.error("input tensor =\n%s", tensor)
+        logger.error("torch_scale =\n%s", scale)
+        logger.error("torch_zero_point =\n%s", _zero_point)
         raise e_value  # Reraise exception
