
Commit 6247ac1

Implement ONNX export for fake_quantize_per_*_affine (#2696)
See pytorch/pytorch#167063. ~The tests in this PR rely on pytorch/pytorch#167465.~
1 parent 53af800 commit 6247ac1

3 files changed (+246, -7 lines)
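Before this change, both functions raised `NotImplementedError` during export. The sketch below is a hypothetical end-to-end check of what the commit enables; the module, input shape, and qparams are illustrative, not taken from the commit.

```python
# Sketch only: run the newly supported op through the dynamo-based ONNX
# exporter. Module, shapes and qparams are made-up examples.
import torch


class FakeQuant(torch.nn.Module):
    def forward(self, x):
        # scale=0.1, zero_point=0, (quant_min, quant_max)=(0, 255): uint8 range
        return torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255)


onnx_program = torch.onnx.export(FakeQuant(), (torch.randn(2, 4),), dynamo=True)
# The exported graph should contain QuantizeLinear/DequantizeLinear nodes.
print(onnx_program.model)
```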

onnxscript/function_libs/torch_lib/ops/core.py (116 additions, 7 deletions)

```diff
@@ -23,6 +23,7 @@
     COMPLEX128,
     DOUBLE,
     FLOAT,
+    FLOAT16,
     INT8,
     INT16,
     INT32,
@@ -3317,17 +3318,58 @@ def aten_eye(n: int) -> TensorType:
     raise NotImplementedError()
 
 
+@torch_op("aten::fake_quantize_per_channel_affine", trace_only=True)
 def aten_fake_quantize_per_channel_affine(
-    self: TensorType,
-    scale: TensorType,
-    zero_point: TensorType,
+    self: TFloat,
+    scale: FLOAT,  # float32 specifically!
+    zero_point: Union[INT32, FLOAT, FLOAT16],  # int32, float32 or float16 only!
     axis: int,
     quant_min: int,
     quant_max: int,
 ) -> TensorType:
     """fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor"""
 
-    raise NotImplementedError()
+    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
+        raise NotImplementedError(
+            "For (quant_min, quant_max), ONNX allows only "
+            "(0, 127), (0, 255) and (-128, 127). "
+            f"Got ({quant_min}, {quant_max})",
+        )
+
+    if quant_min == 0:
+        int_dtype = ir.DataType.UINT8
+    else:
+        int_dtype = ir.DataType.INT8
+
+    # TODO: When opset >= 19, remove this cast
+    orig_dtype = self.type.dtype
+    if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}:
+        self = op.Cast(self, to=ir.DataType.FLOAT)
+
+    if zero_point.type.dtype == ir.DataType.INT32:
+        zero_point = op.Cast(zero_point, to=int_dtype)
+    else:
+        raise NotImplementedError(
+            "ONNX only supports integer values for the zero_point parameter. "
+            f"Got {zero_point.type.dtype}",
+        )
+
+    quantized = op.QuantizeLinear(self, scale, zero_point, axis=axis)
+
+    # See the comment above: PyTorch-specific (0, 127) handling
+    if (quant_min, quant_max) == (0, 127):
+        const_127 = op.Cast(127, to=int_dtype)
+        quantized = op.Clip(quantized, max=const_127)
+
+    output = op.DequantizeLinear(quantized, scale, zero_point, axis=axis)
+
+    # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear
+    if orig_dtype != ir.DataType.FLOAT:
+        output = op.Cast(output, to=orig_dtype)
+
+    return output
 
 
 def aten_fake_quantize_per_channel_affine_cachemask(
@@ -3351,12 +3393,79 @@ def aten_fake_quantize_per_channel_affine_cachemask_backward(
     raise NotImplementedError()
 
 
+@torch_op("aten::fake_quantize_per_tensor_affine", trace_only=True)
 def aten_fake_quantize_per_tensor_affine(
-    self: TensorType, scale: float, zero_point: int, quant_min: int, quant_max: int
-) -> TensorType:
+    self: TFloat,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+) -> TFloat:
     """fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor"""
 
-    raise NotImplementedError()
+    return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max)
+
+
+@torch_op("aten::fake_quantize_per_tensor_affine.tensor_qparams", trace_only=True)
+def aten_fake_quantize_per_tensor_affine_tensor_qparams(
+    self: TFloat,
+    scale: TReal,
+    zero_point: TReal,
+    quant_min: int,
+    quant_max: int,
+) -> TFloat:
+    """fake_quantize_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor"""
+
+    return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max)
+
+
+def _aten_fake_quantize_per_tensor_affine(
+    self: TFloat,
+    scale: Union[float, TReal],
+    zero_point: Union[int, TReal],
+    quant_min: int,
+    quant_max: int,
+) -> TFloat:
+    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
+        raise NotImplementedError(
+            "For (quant_min, quant_max), ONNX allows only "
+            "(0, 127), (0, 255) and (-128, 127). "
+            f"Got ({quant_min}, {quant_max})",
+        )
+
+    if quant_min == 0:
+        int_dtype = ir.DataType.UINT8
+    else:
+        int_dtype = ir.DataType.INT8
+
+    # TODO: When opset >= 19, remove this cast
+    orig_dtype = self.type.dtype
+    if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}:
+        self = op.Cast(self, to=ir.DataType.FLOAT)
+
+    # TODO: When opset >= 19, relax the condition for this cast
+    if isinstance(scale, float) or scale.type.dtype != ir.DataType.FLOAT:
+        scale = op.Cast(scale, to=ir.DataType.FLOAT)
+
+    if isinstance(zero_point, int) or zero_point.type.dtype != int_dtype:
+        zero_point = op.Cast(zero_point, to=int_dtype)
+
+    quantized = op.QuantizeLinear(self, scale, zero_point)
+
+    # See the comment above: PyTorch-specific (0, 127) handling
+    if (quant_min, quant_max) == (0, 127):
+        const_127 = op.Cast(127, to=int_dtype)
+        quantized = op.Clip(quantized, max=const_127)
+
+    output = op.DequantizeLinear(quantized, scale, zero_point)
+
+    # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear
+    if orig_dtype != ir.DataType.FLOAT:
+        output = op.Cast(output, to=orig_dtype)
+
+    return output
 
 
 def aten_fake_quantize_per_tensor_affine_cachemask(
```
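Both functions lower fake quantization to a `QuantizeLinear`/`DequantizeLinear` pair, plus a `Clip` for the PyTorch-specific (0, 127) range. A minimal numeric sketch of that decomposition, assuming the per-tensor uint8 path with scalar qparams (values are made up):

```python
# Reference check: fake_quantize == quantize (round, shift, saturate)
# followed by dequantize (shift back, rescale).
import torch

x = torch.randn(8)
scale, zero_point, quant_min, quant_max = 0.05, 3, 0, 255

# QuantizeLinear: round to nearest even, add zero_point, saturate
q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
# DequantizeLinear: subtract zero_point, rescale
ref = (q - zero_point) * scale

out = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, quant_min, quant_max)
torch.testing.assert_close(out, ref)
```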

tests/function_libs/torch_lib/extra_opinfo.py (119 additions, 0 deletions)

```diff
@@ -779,6 +779,109 @@ def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_):
     )
 
 
+def sample_inputs_fake_quantize_per_tensor_affine(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    del op_info, kwargs  # Unused
+    make_arg = functools.partial(
+        opinfo_core.make_tensor,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+    # Test 1D, empty and scalar tensors (like sample_inputs_elementwise_unary)
+    shapes = [
+        (S,),
+        (1, 0, 3),
+        (),
+    ]
+
+    scale_zero_point_dtypes = [
+        # default (float, int)
+        (None, None)
+    ] + [
+        # tensor_qparams (tensor, tensor)
+        (t1, t2)
+        for t1 in common_dtype.all_types_and()
+        for t2 in common_dtype.all_types_and()
+    ]
+
+    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    quant_vals = [(0, 255), (-128, 127), (0, 127)]
+
+    cases = itertools.product(shapes, scale_zero_point_dtypes, quant_vals)
+    for shape, (scale_dtype, zero_point_dtype), (quant_min, quant_max) in cases:
+        scale = make_arg(
+            (),
+            dtype=scale_dtype or torch.float64,
+        )
+        if scale_dtype is None:
+            scale = scale.item()
+
+        zero_point = make_arg(
+            (),
+            dtype=zero_point_dtype or torch.int64,
+            # zero_point must be between quant_min and quant_max
+            low=quant_min,
+            high=quant_max,
+        )
+        if zero_point_dtype is None:
+            zero_point = zero_point.item()
+
+        args = (scale, zero_point, quant_min, quant_max)
+        yield opinfo_core.SampleInput(make_arg(shape, dtype=dtype), args=args)
+
+
+def sample_inputs_fake_quantize_per_channel_affine(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    del op_info, kwargs  # Unused
+    make_arg = functools.partial(
+        opinfo_core.make_tensor,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+    # Test 1D, 2D, 4D and empty tensors (scalar tensors are not supported)
+    axes_and_shapes = [
+        # 1D, 2D, 4D
+        (axis, (S,) * dims)
+        for dims in (1, 2, 4)
+        for axis in range(dims)
+    ] + [
+        # empty
+        (0, (1, 0, 3)),
+        (2, (1, 0, 3)),
+        # an empty channel axis causes an error due to
+        # an internal zero_point.min() calculation
+        # (1, (1, 0, 3)),
+    ]
+
+    # tensor_qparams
+    scale_dtype = torch.float
+    zero_point_dtypes = [torch.int32, torch.float, torch.half]
+
+    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    quant_vals = [(0, 255), (-128, 127), (0, 127)]
+
+    cases = itertools.product(axes_and_shapes, zero_point_dtypes, quant_vals)
+    for (axis, shape), zero_point_dtype, (quant_min, quant_max) in cases:
+        scale = make_arg((shape[axis],), dtype=scale_dtype)
+
+        zero_point = make_arg(
+            (shape[axis],),
+            dtype=zero_point_dtype or torch.int64,
+            # zero_point must be between quant_min and quant_max
+            low=quant_min,
+            high=quant_max,
+        )
+
+        args = (scale, zero_point, axis, quant_min, quant_max)
+        yield opinfo_core.SampleInput(make_arg(shape, dtype=dtype), args=args)
+
+
 def _index_variable_bool(shape, max_indices, device):
     if not isinstance(shape, tuple):
         shape = (shape,)
@@ -2408,6 +2511,22 @@ def __init__(self):
         sample_inputs_func=sample_inputs__fft_r2c,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten.fake_quantize_per_tensor_affine",
+        aten_name="fake_quantize_per_tensor_affine",
+        op=torch.fake_quantize_per_tensor_affine,
+        dtypes=common_dtype.floating_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_fake_quantize_per_tensor_affine,
+        supports_out=False,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten.fake_quantize_per_channel_affine",
+        aten_name="fake_quantize_per_channel_affine",
+        op=torch.fake_quantize_per_channel_affine,
+        dtypes=common_dtype.floating_types_and(torch.half, torch.bfloat16),
+        sample_inputs_func=sample_inputs_fake_quantize_per_channel_affine,
+        supports_out=False,
+    ),
     opinfo_core.BinaryUfuncInfo(
         "ops.aten.floor_divide",
         aten_name="floor_divide",
```

tests/function_libs/torch_lib/ops_test_data.py (11 additions, 0 deletions)

```diff
@@ -698,6 +698,17 @@ def _where_input_wrangler(
     TorchLibOpInfo("special.erfcx", special_ops.aten_special_erfcx).xfail(
         reason="fixme: The implementation is numerically unstable: https://github.com/microsoft/onnxscript/issues/1223"
     ),
+    TorchLibOpInfo(
+        "ops.aten.fake_quantize_per_channel_affine",
+        core_ops.aten_fake_quantize_per_channel_affine,
+    ).xfail(
+        reason="fixme: ONNX (De)QuantizeLinear only supports integer zero_point values",
+        matcher=lambda sample: sample.args[1].dtype != torch.int32,
+    ),
+    TorchLibOpInfo(
+        "ops.aten.fake_quantize_per_tensor_affine",
+        core_ops.aten_fake_quantize_per_tensor_affine,
+    ),
     TorchLibOpInfo("fill", core_ops.aten_fill),
     TorchLibOpInfo("flip", core_ops.aten_flip).skip(
         reason="fixme: size 0 inputs are not handled yet",
```
