from torch.onnx import register_custom_op_symbolic

# Register symbolic op for torch.quantize_function op.
+ import functools
+ from torch.onnx._internal import jit_utils, registration
+ _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13)
+ from torch.onnx import (
+     _type_utils,
+     symbolic_helper,
+     symbolic_opset9 as opset9,
+ )
+ import torch._C._onnx as _C_onnx
+ import torch
+ @_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
+ @symbolic_helper.parse_args("v", "v", "v", "i", "i")
+ def fake_quantize_per_tensor_affine(
+     g: jit_utils.GraphContext,
+     inputs,
+     scale,
+     zero_point,
+     quant_min=-128,
+     quant_max=127,
+ ):
+     # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
+     #   https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+     # if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
+     #     raise errors.SymbolicValueError(
+     #         "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
+     #         f"Got ({quant_min}, {quant_max})",
+     #         inputs,
+     #     )
+     if quant_min == 0:
+         zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
+     else:
+         zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
+     if (
+         _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED)
+         != _type_utils.JitScalarType.FLOAT
+     ):
+         scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
+     quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
+     if (quant_min, quant_max) == (0, 127):
+         quantized = g.op(
+             "Clip",
+             quantized,
+             opset9.unused(g),
+             g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
+         )
+     return g.op("DequantizeLinear", quantized, scale, zero_point)

- def _fake_quantize_learnable_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max, grad_factor):
-     return g.op("::LearnablePerTensorAffine", x, scale, zero_point, quant_min, quant_max)
+ @_onnx_symbolic("aten::fake_quantize_per_channel_affine")
+ @symbolic_helper.parse_args("v", "v", "v", "i", "i", "i")
+ def fake_quantize_per_channel_affine(
+     g: jit_utils.GraphContext,
+     inputs,
+     scale,
+     zero_point,
+     axis,
+     quant_min=-128,
+     quant_max=127,
+ ):
+     # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
+     #   https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+     # if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
+     #     raise errors.SymbolicValueError(
+     #         "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
+     #         f"Got ({quant_min}, {quant_max})",
+     #         inputs,
+     #     )
+     # ONNX defines zero_point to be int8 or uint8
+     if quant_min == 0:
+         zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
+     else:
+         zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
+     quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis)
+     if (quant_min, quant_max) == (0, 127):
+         quantized = g.op(
+             "Clip",
+             quantized,
+             opset9.unused(g),
+             g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
+         )
+     return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis)

+ @_onnx_symbolic("aten::_fake_quantize_learnable_per_tensor_affine")
+ @symbolic_helper.parse_args("v", "v", "v", "i", "i", "i")
+ def _fake_quantize_learnable_per_tensor_affine(
+     g: jit_utils.GraphContext,
+     inputs,
+     scale,
+     zero_point,
+     quant_min=-128,
+     quant_max=127,
+     grad_factor=0,
+ ):
+     # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
+     #   https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+     # if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
+     #     raise errors.SymbolicValueError(
+     #         "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
+     #         f"Got ({quant_min}, {quant_max})",
+     #         inputs,
+     #     )
+     if quant_min == 0:
+         zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
+     else:
+         zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
+     if (
+         _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED)
+         != _type_utils.JitScalarType.FLOAT
+     ):
+         scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
+     quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
+     if (quant_min, quant_max) == (0, 127):
+         quantized = g.op(
+             "Clip",
+             quantized,
+             opset9.unused(g),
+             g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
+         )
+     return g.op("DequantizeLinear", quantized, scale, zero_point)

- register_custom_op_symbolic('::_fake_quantize_learnable_per_tensor_affine', _fake_quantize_learnable_per_tensor_affine, 11)

+ # def _fake_quantize_learnable_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max, grad_factor):
+ #     return g.op(x, scale, zero_point, quant_min, quant_max)
+ #
+ #
+ # register_custom_op_symbolic('::_fake_quantize_learnable_per_tensor_affine', _fake_quantize_learnable_per_tensor_affine, 11)
+ #
+ #
+ # def fake_quantize_per_channel_affine(g, x, scale, zero_point, ch_axis, quant_min, quant_max):
+ #     return g.op("::FixedPerChannelAffine", x, scale, zero_point, ch_axis, quant_min, quant_max)
+ #
+ #
+ # register_custom_op_symbolic('::fake_quantize_per_channel_affine', fake_quantize_per_channel_affine, 11)
+ #
+ #
+ # def fake_quantize_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max):
+ #     return g.op("::FixedPerTensorAffine", x, scale, zero_point, quant_min, quant_max)
+ #
+ #
+ # register_custom_op_symbolic('::fake_quantize_per_tensor_affine', fake_quantize_per_tensor_affine, 11)

- def fake_quantize_per_channel_affine(g, x, scale, zero_point, ch_axis, quant_min, quant_max):
-     return g.op("::FixedPerChannelAffine", x, scale, zero_point, ch_axis, quant_min, quant_max)
-
-
- register_custom_op_symbolic('::fake_quantize_per_channel_affine', fake_quantize_per_channel_affine, 11)
-
-
- def fake_quantize_per_tensor_affine(g, x, scale, zero_point, quant_min, quant_max):
-     return g.op("::FixedPerTensorAffine", x, scale, zero_point, quant_min, quant_max)
-
-
- register_custom_op_symbolic('::fake_quantize_per_tensor_affine', fake_quantize_per_tensor_affine, 11)
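
For reference, a minimal usage sketch (not part of the diff above): once this module is imported, exporting a fake-quantized model at opset 13 goes through the symbolics registered here and emits QuantizeLinear/DequantizeLinear pairs. The module, tensor shapes, and output filename below are hypothetical placeholders; FakeQuantize, disable_observer, and torch.onnx.export are standard PyTorch APIs.

import torch
import torch.ao.quantization as tq

# Toy module with a per-tensor FakeQuantize on its output (illustrative only).
class TinyQATModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16, 8)
        self.act_fq = tq.FakeQuantize()  # default: per-tensor affine, quint8, range (0, 255)

    def forward(self, x):
        return self.act_fq(self.fc(x))

model = TinyQATModel().eval()
# Freeze the observer so tracing only records aten::fake_quantize_per_tensor_affine.
model.apply(tq.disable_observer)

# Exporting at opset 13 picks up fake_quantize_per_tensor_affine as defined above.
torch.onnx.export(model, torch.randn(1, 16), "tiny_qat.onnx", opset_version=13)

A FakeQuantize configured with a per-channel observer and ch_axis would instead be traced as aten::fake_quantize_per_channel_affine and handled by the second symbolic in the same way.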