
Commit 92d0b73

implement onnx conversion for aten::fake_quantize_per_tensor_affine

1 parent 9e5fa4e

File tree

1 file changed: +70 -3 lines

  • onnxscript/function_libs/torch_lib/ops/core.py

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 70 additions & 3 deletions
@@ -3393,12 +3393,79 @@ def aten_fake_quantize_per_channel_affine_cachemask_backward(
     raise NotImplementedError()
 
 
+@torch_op("aten::fake_quantize_per_tensor_affine", trace_only=True)
 def aten_fake_quantize_per_tensor_affine(
-    self: TensorType, scale: float, zero_point: int, quant_min: int, quant_max: int
-) -> TensorType:
+    self: TFloat,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+) -> TFloat:
     """fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor"""
 
-    raise NotImplementedError()
+    return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max)
+
+
+@torch_op("aten::fake_quantize_per_tensor_affine.tensor_qparams", trace_only=True)
+def aten_fake_quantize_per_tensor_affine_tensor_qparams(
+    self: TFloat,
+    scale: TReal,
+    zero_point: TReal,
+    quant_min: int,
+    quant_max: int,
+) -> TFloat:
+    """fake_quantize_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor"""
+
+    return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max)
+
+
+def _aten_fake_quantize_per_tensor_affine(
+    self: TFloat,
+    scale: Union[float, TReal],
+    zero_point: Union[int, TReal],
+    quant_min: int,
+    quant_max: int,
+) -> TFloat:
+    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
+        raise NotImplementedError(
+            "For (quant_min, quant_max), ONNX allows only "
+            "(0, 127), (0, 255) and (-128, 127). "
+            f"Got ({quant_min}, {quant_max})",
+        )
+
+    if quant_min == 0:
+        int_dtype = ir.DataType.UINT8
+    else:
+        int_dtype = ir.DataType.INT8
+
+    # TODO: When opset >= 19, remove this cast
+    orig_dtype = self.type.dtype
+    if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}:
+        self = op.Cast(self, to=ir.DataType.FLOAT)
+
+    # TODO: When opset >= 19, relax the condition for this cast
+    if isinstance(scale, float) or scale.type.dtype != ir.DataType.FLOAT:
+        scale = op.Cast(scale, to=ir.DataType.FLOAT)
+
+    if isinstance(zero_point, int) or zero_point.type.dtype != int_dtype:
+        zero_point = op.Cast(zero_point, to=int_dtype)
+
+    quantized = op.QuantizeLinear(self, scale, zero_point)
+
+    # See the NOTE above: PyTorch-specific (0, 127) handling
+    if (quant_min, quant_max) == (0, 127):
+        const_127 = op.Cast(127, to=int_dtype)
+        quantized = op.Clip(quantized, max=const_127)
+
+    output = op.DequantizeLinear(quantized, scale, zero_point)
+
+    # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear
+    if orig_dtype != ir.DataType.FLOAT:
+        output = op.Cast(output, to=orig_dtype)
+
+    return output
 
 
 def aten_fake_quantize_per_tensor_affine_cachemask(
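
For reference, the QuantizeLinear → Clip → DequantizeLinear sequence above reproduces PyTorch's per-tensor fake-quantization semantics: round onto the integer grid, clamp to [quant_min, quant_max], then map back to float. Below is a minimal NumPy sketch of that round trip; the helper name and the sample values are illustrative assumptions, not part of the commit.

import numpy as np

def fake_quant_ref(x, scale, zero_point, quant_min, quant_max):
    # Quantize: scale, shift by the zero point, round half-to-even,
    # and clamp to the integer range (QuantizeLinear plus the optional Clip above).
    q = np.clip(np.round(x / scale) + zero_point, quant_min, quant_max)
    # Dequantize back to float (DequantizeLinear above).
    return ((q - zero_point) * scale).astype(np.float32)

x = np.array([-1.5, -0.3, 0.0, 0.7, 2.4], dtype=np.float32)
print(fake_quant_ref(x, scale=0.1, zero_point=128, quant_min=0, quant_max=255))
# The same values can be cross-checked against the eager op:
# torch.fake_quantize_per_tensor_affine(torch.from_numpy(x), 0.1, 128, 0, 255)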
