@@ -3393,12 +3393,79 @@ def aten_fake_quantize_per_channel_affine_cachemask_backward(
     raise NotImplementedError()


+@torch_op("aten::fake_quantize_per_tensor_affine", trace_only=True)
 def aten_fake_quantize_per_tensor_affine(
-    self: TensorType, scale: float, zero_point: int, quant_min: int, quant_max: int
-) -> TensorType:
+    self: TFloat,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+) -> TFloat:
     """fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor"""

-    raise NotImplementedError()
+    return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max)
+
+
+@torch_op("aten::fake_quantize_per_tensor_affine.tensor_qparams", trace_only=True)
+def aten_fake_quantize_per_tensor_affine_tensor_qparams(
+    self: TFloat,
+    scale: TReal,
+    zero_point: TReal,
+    quant_min: int,
+    quant_max: int,
+) -> TFloat:
3417+ """fake_quantize_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor"""
+
+    return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max)
+
+
+def _aten_fake_quantize_per_tensor_affine(
+    self: TFloat,
+    scale: Union[float, TReal],
+    zero_point: Union[int, TReal],
+    quant_min: int,
+    quant_max: int,
+) -> TFloat:
+    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
+        raise NotImplementedError(
+            "For (quant_min, quant_max), ONNX allows only "
+            "(0, 127), (0, 255) and (-128, 127). "
+            f"Got ({quant_min}, {quant_max})",
+        )
+
+    if quant_min == 0:
+        int_dtype = ir.DataType.UINT8
+    else:
+        int_dtype = ir.DataType.INT8
+
+    # TODO: When opset >= 19, remove this cast
+    orig_dtype = self.type.dtype
+    if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}:
+        self = op.Cast(self, to=ir.DataType.FLOAT)
+
+    # TODO: When opset >= 19, relax the condition for this cast
+    if isinstance(scale, float) or scale.type.dtype != ir.DataType.FLOAT:
+        scale = op.Cast(scale, to=ir.DataType.FLOAT)
+
+    if isinstance(zero_point, int) or zero_point.type.dtype != int_dtype:
+        zero_point = op.Cast(zero_point, to=int_dtype)
+
+    quantized = op.QuantizeLinear(self, scale, zero_point)
+
+    # See the comment above about the PyTorch-specific (0, 127) handling
+    if (quant_min, quant_max) == (0, 127):
+        const_127 = op.Cast(127, to=int_dtype)
+        quantized = op.Clip(quantized, max=const_127)
+
+    output = op.DequantizeLinear(quantized, scale, zero_point)
+
+    # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear
+    if orig_dtype != ir.DataType.FLOAT:
+        output = op.Cast(output, to=orig_dtype)
+
+    return output


 def aten_fake_quantize_per_tensor_affine_cachemask(
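For reference, here is a minimal sketch of the quantize-then-dequantize identity this decomposition expresses with QuantizeLinear/DequantizeLinear. It is not part of the commit; the tensor values and parameters are made up for illustration.

import torch

# Illustrative inputs, assuming the common uint8 range (0, 255).
x = torch.tensor([-1.5, -0.3, 0.0, 0.4, 1.234])
scale, zero_point, quant_min, quant_max = 0.1, 0, 0, 255

expected = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, quant_min, quant_max)

# Manual quantize -> clamp -> dequantize, mirroring what the ONNX
# QuantizeLinear/DequantizeLinear pair computes for UINT8.
q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
manual = (q - zero_point) * scale

assert torch.allclose(expected, manual)  # both yield [0.0, 0.0, 0.0, 0.4, 1.2]

The extra Clip in the (0, 127) branch above is needed because uint8 QuantizeLinear saturates to 255, not 127, so the PyTorch-specific (0, 127) range must be enforced separately.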