     COMPLEX128,
     DOUBLE,
     FLOAT,
+    FLOAT16,
     INT8,
     INT16,
     INT32,
@@ -3317,17 +3318,58 @@ def aten_eye(n: int) -> TensorType:
     raise NotImplementedError()


+@torch_op("aten::fake_quantize_per_channel_affine", trace_only=True)
 def aten_fake_quantize_per_channel_affine(
-    self: TensorType,
-    scale: TensorType,
-    zero_point: TensorType,
+    self: TFloat,
+    scale: FLOAT,  # float32 specifically!
+    zero_point: Union[INT32, FLOAT, FLOAT16],  # int32, float32 or float16 only!
     axis: int,
     quant_min: int,
     quant_max: int,
 ) -> TensorType:
     """fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor"""

-    raise NotImplementedError()
+    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
+        raise NotImplementedError(
+            "For (quant_min, quant_max), ONNX allows only "
+            "(0, 127), (0, 255) and (-128, 127). "
+            f"Got ({quant_min}, {quant_max})",
+        )
+
+    # A zero-based range quantizes to unsigned int8; a symmetric range quantizes to signed int8.
+    if quant_min == 0:
+        int_dtype = ir.DataType.UINT8
+    else:
+        int_dtype = ir.DataType.INT8
+
+    # TODO: When opset >= 19, remove this cast
+    orig_dtype = self.type.dtype
+    if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}:
+        self = op.Cast(self, to=ir.DataType.FLOAT)
+
+    if zero_point.type.dtype == ir.DataType.INT32:
+        zero_point = op.Cast(zero_point, to=int_dtype)
+    else:
+        raise NotImplementedError(
+            "ONNX only supports integer values for the zero_point parameter. "
+            f"Got {zero_point.type.dtype}",
+        )
+
+    quantized = op.QuantizeLinear(self, scale, zero_point, axis=axis)
+
+    # See the comment above about the PyTorch-specific (0, 127) handling
+    if (quant_min, quant_max) == (0, 127):
+        const_127 = op.Cast(127, to=int_dtype)
+        quantized = op.Clip(quantized, max=const_127)
+
+    output = op.DequantizeLinear(quantized, scale, zero_point, axis=axis)
+
+    # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear
+    if orig_dtype != ir.DataType.FLOAT:
+        output = op.Cast(output, to=orig_dtype)
+
+    return output
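For reference, the fake-quantize op is just a quantize/dequantize round trip, which the QuantizeLinear -> Clip -> DequantizeLinear chain above reproduces. A minimal NumPy sketch of the semantics (illustrative only, not part of this change; the helper name is made up):

import numpy as np

def fake_quantize(x, scale, zero_point, quant_min, quant_max):
    # Quantize: scale, shift by zero_point, round, clamp to the integer range.
    q = np.clip(np.round(x / scale) + zero_point, quant_min, quant_max)
    # Dequantize back to float; the round trip introduces the quantization error.
    return ((q - zero_point) * scale).astype(np.float32)

x = np.array([-1.0, -0.05, 0.0, 0.6, 3.0], dtype=np.float32)
print(fake_quantize(x, scale=0.01, zero_point=0, quant_min=0, quant_max=127))
# [0.   0.   0.   0.6  1.27]  -- negatives clamp to 0; 3.0 saturates at 127 * 0.01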


 def aten_fake_quantize_per_channel_affine_cachemask(
@@ -3351,12 +3393,79 @@ def aten_fake_quantize_per_channel_affine_cachemask_backward(
     raise NotImplementedError()


+@torch_op("aten::fake_quantize_per_tensor_affine", trace_only=True)
 def aten_fake_quantize_per_tensor_affine(
-    self: TensorType, scale: float, zero_point: int, quant_min: int, quant_max: int
-) -> TensorType:
+    self: TFloat,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+) -> TFloat:
     """fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor"""

-    raise NotImplementedError()
+    return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max)
+
+
+@torch_op("aten::fake_quantize_per_tensor_affine.tensor_qparams", trace_only=True)
+def aten_fake_quantize_per_tensor_affine_tensor_qparams(
+    self: TFloat,
+    scale: TReal,
+    zero_point: TReal,
+    quant_min: int,
+    quant_max: int,
+) -> TFloat:
+    """fake_quantize_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor"""
+
+    return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max)
+
+
+def _aten_fake_quantize_per_tensor_affine(
+    self: TFloat,
+    scale: Union[float, TReal],
+    zero_point: Union[int, TReal],
+    quant_min: int,
+    quant_max: int,
+) -> TFloat:
+    # NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
+        raise NotImplementedError(
+            "For (quant_min, quant_max), ONNX allows only "
+            "(0, 127), (0, 255) and (-128, 127). "
+            f"Got ({quant_min}, {quant_max})",
+        )
+
+    # A zero-based range quantizes to unsigned int8; a symmetric range quantizes to signed int8.
+    if quant_min == 0:
+        int_dtype = ir.DataType.UINT8
+    else:
+        int_dtype = ir.DataType.INT8
+
+    # TODO: When opset >= 19, remove this cast
+    orig_dtype = self.type.dtype
+    if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}:
+        self = op.Cast(self, to=ir.DataType.FLOAT)
+
+    # TODO: When opset >= 19, relax the condition for this cast
+    if isinstance(scale, float) or scale.type.dtype != ir.DataType.FLOAT:
+        scale = op.Cast(scale, to=ir.DataType.FLOAT)
+
+    if isinstance(zero_point, int) or zero_point.type.dtype != int_dtype:
+        zero_point = op.Cast(zero_point, to=int_dtype)
+
+    quantized = op.QuantizeLinear(self, scale, zero_point)
+
+    # See the comment above about the PyTorch-specific (0, 127) handling
+    if (quant_min, quant_max) == (0, 127):
+        const_127 = op.Cast(127, to=int_dtype)
+        quantized = op.Clip(quantized, max=const_127)
+
+    output = op.DequantizeLinear(quantized, scale, zero_point)
+
+    # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear
+    if orig_dtype != ir.DataType.FLOAT:
+        output = op.Cast(output, to=orig_dtype)
+
+    return output
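As a quick sanity check, the eager PyTorch op gives the values the exported QuantizeLinear/Clip/DequantizeLinear graph should match; the (0, 127) range exercises the special-cased Clip path added above (illustrative snippet, not part of this change):

import torch

x = torch.tensor([-1.0, -0.05, 0.0, 0.6, 3.0])
y = torch.fake_quantize_per_tensor_affine(x, scale=0.01, zero_point=0, quant_min=0, quant_max=127)
print(y)  # tensor([0.0000, 0.0000, 0.0000, 0.6000, 1.2700])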


 def aten_fake_quantize_per_tensor_affine_cachemask(