@@ -66,33 +66,18 @@ def get_args_and_kwargs_add(
     dequants_inputs: List[fx.Node],
     quant_node: fx.Node,
 ) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
-    X_scale_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[0].args[1]),
-        {"dtype": torch.float},
-    )
-    X_zero_point_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[0].args[2]),
-        {"dtype": torch.int32},
-    )
-    Y_scale_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[1].args[1]),
-        {"dtype": torch.float},
-    )
-    Y_zero_point_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[1].args[2]),
-        {"dtype": torch.int32},
-    )
+    X_scale = dequants_inputs[0].args[1]
+
+    X_zero_point = dequants_inputs[0].args[2]
+    Y_scale = dequants_inputs[1].args[1]
+    Y_zero_point = dequants_inputs[1].args[2]
     args = (
         inputs_inputs[0],
-        X_scale_,
-        X_zero_point_,
+        X_scale,
+        X_zero_point,
         inputs_inputs[1],
-        Y_scale_,
-        Y_zero_point_,
+        Y_scale,
+        Y_zero_point,
         quant_node.args[1],
         quant_node.args[2],
     )
@@ -130,31 +115,12 @@ def get_args_and_kwargs_linear(
     else:
         bias = bias_inputs[0]
 
-    # Create single element tensors for weight_zero_point, out_multiplier, out_shift.
-    # Note that the function expects int32_t, when it would default to int64_t, so
-    # we explicitly require that type.
-    weight_zero_point_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_weights[0].args[2]),
-        {"dtype": torch.int32},
-    )
-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
     args = tuple(inputs_inputs + weights_inputs + [bias])
     kwargs = {
         "src_zero_point": dequants_inputs[0].args[2],
-        "weight_zero_point": weight_zero_point_,
-        "out_multiplier": out_multiplier_,
-        "out_shift": out_shift_,
+        "weight_zero_point": dequants_weights[0].args[2],
+        "out_multiplier": out_multiplier[0].item(),
+        "out_shift": out_shift[0].item(),
         "out_zero_point": quant_node.args[2],
         "offset": None,
     }
@@ -179,22 +145,8 @@ def get_args_and_kwargs_layer_norm(
     ), "per-channel quantization is not supported for layer norm, both scale and zero_point should be scalars"
 
     # Make the scale and zero_point tensors
-    scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[1],
-        ),
-        {"dtype": torch.float32},
-    )
-    zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[2],
-        ),
-        {"dtype": torch.int32},
-    )
+    scale = dequants_inputs[0].args[1]
+    zero_point = dequants_inputs[0].args[2]
 
     weight = other_inputs[1] if len(other_inputs) > 1 else None
 
@@ -221,7 +173,7 @@ def get_args_and_kwargs_layer_norm(
     )
 
     # Make the args and kwargs for the replacement op
-    args = tuple(inputs_inputs + [scale_tensor] + [zero_point_tensor])
+    args = tuple(inputs_inputs + [scale, zero_point])
     kwargs = {
         "normalized_shape": other_inputs[0],
         "weight": weight,
@@ -309,31 +261,6 @@ def get_args_and_kwargs_conv(
 
     (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t)
 
-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
-    # Create a single element tensor for the weight zero point
-    weight_zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], weight_zero_point),
-        {"dtype": torch.int32},
-    )
-
-    # Create a single element tensor for the bias scale
-    bias_scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], bias_scale),
-        {"dtype": torch.float32},
-    )
-
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs + weights_inputs + [bias])
     kwargs = {
@@ -342,12 +269,12 @@ def get_args_and_kwargs_conv(
342269 "dilation" : dilation ,
343270 "groups" : groups ,
344271 "input_zero_point" : dequants_inputs [0 ].args [2 ],
345- "weight_zero_point" : weight_zero_point_tensor ,
346- "bias_scale" : bias_scale_tensor ,
272+ "weight_zero_point" : weight_zero_point ,
273+ "bias_scale" : bias_scale ,
347274 "out_scale" : quant_node .args [1 ],
348275 "out_zero_point" : quant_node .args [2 ],
349- "out_multiplier" : out_multiplier_ ,
350- "out_shift" : out_shift_ ,
276+ "out_multiplier" : out_multiplier [ 0 ]. item () ,
277+ "out_shift" : out_shift [ 0 ]. item () ,
351278 }
352279 return args , kwargs
353280
@@ -368,27 +295,11 @@ def get_args_and_kwargs_relu(
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs)
 
-    X_zero_point = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[0].args[2]),
-        {"dtype": torch.int32},
-    )
-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
     kwargs = {
-        "X_zero_point": X_zero_point,
+        "X_zero_point": dequants_inputs[0].args[2],
         "out_zero_point": quant_node.args[2],
-        "out_multiplier": out_multiplier_,
-        "out_shift": out_shift_,
+        "out_multiplier": out_multiplier[0].item(),
+        "out_shift": out_shift[0].item(),
     }
     return args, kwargs
 
@@ -436,48 +347,20 @@ def get_args_and_kwargs_softmax(
436347 {"dtype" : torch .int32 },
437348 )
438349 # Make the scale and zero_point tensors
439- in_scale_tensor = graph_module .graph .call_function (
440- torch .ops .aten .full .default ,
441- (
442- [1 ],
443- dequants_inputs [0 ].args [1 ],
444- ),
445- {"dtype" : torch .float32 },
446- )
447- in_zero_point_tensor = graph_module .graph .call_function (
448- torch .ops .aten .full .default ,
449- (
450- [1 ],
451- dequants_inputs [0 ].args [2 ],
452- ),
453- {"dtype" : torch .int32 },
454- )
455- out_scale_tensor = graph_module .graph .call_function (
456- torch .ops .aten .full .default ,
457- (
458- [1 ],
459- quant_node .args [1 ],
460- ),
461- {"dtype" : torch .float32 },
462- )
463- out_zero_point_tensor = graph_module .graph .call_function (
464- torch .ops .aten .full .default ,
465- (
466- [1 ],
467- quant_node .args [2 ],
468- ),
469- {"dtype" : torch .int32 },
470- )
350+ in_scale = dequants_inputs [0 ].args [1 ]
351+ in_zero_point = dequants_inputs [0 ].args [2 ]
352+ out_scale = quant_node .args [1 ]
353+ out_zero_point = quant_node .args [2 ]
471354
472355 # Make the args and kwargs for the replacement op
473356 args = (
474357 inputs_inputs [0 ],
475358 mask_tensor ,
476359 op_node .args [1 ],
477- in_scale_tensor ,
478- in_zero_point_tensor ,
479- out_scale_tensor ,
480- out_zero_point_tensor ,
360+ in_scale ,
361+ in_zero_point ,
362+ out_scale ,
363+ out_zero_point ,
481364 )
482365 kwargs = {}
483366
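Every hunk above applies the same refactor: instead of materializing each scalar quantization parameter (scale, zero point, out_multiplier, out_shift) as a one-element tensor through an aten.full node inserted into the FX graph, the helpers now pass the Python scalar straight through in args/kwargs. Below is a minimal, self-contained sketch of that before/after node surgery on a toy traced function; the function f, the scale value, and the add node stand in for the real replacement ops and are illustrative only, not part of this commit.

# Hedged sketch, not from the commit: toy graph, illustrative names.
import torch
import torch.fx as fx


def f(x):
    return x + 1.0


gm = fx.symbolic_trace(f)
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")
scale = 0.05  # stand-in for dequants_inputs[0].args[1]

# Old pattern: materialize the scalar as a 1-element tensor by inserting
# an aten.full node into the graph, then feed that node to the op.
with gm.graph.inserting_before(add_node):
    scale_tensor = gm.graph.call_function(
        torch.ops.aten.full.default,
        ([1], scale),
        {"dtype": torch.float32},
    )
add_node.args = (add_node.args[0], scale_tensor)

# New pattern: embed the Python scalar directly in the node's args, which
# works whenever the target op's schema accepts a Scalar rather than a Tensor.
add_node.args = (add_node.args[0], scale)
gm.graph.eliminate_dead_code()  # the now-unused aten.full node is dropped
gm.recompile()
print(gm(torch.zeros(3)))  # tensor([0.0500, 0.0500, 0.0500])

The upside of the new pattern, assuming the op schemas allow it, is a smaller graph with fewer constant-producing nodes to lower and no dtype bookkeeping for the single-element tensors.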