@@ -65,33 +65,18 @@ def get_args_and_kwargs_add(
     dequants_inputs: List[fx.Node],
     quant_node: fx.Node,
 ) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
-    X_scale_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[0].args[1]),
-        {"dtype": torch.float},
-    )
-    X_zero_point_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[0].args[2]),
-        {"dtype": torch.int32},
-    )
-    Y_scale_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[1].args[1]),
-        {"dtype": torch.float},
-    )
-    Y_zero_point_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[1].args[2]),
-        {"dtype": torch.int32},
-    )
+    X_scale = dequants_inputs[0].args[1]
+
+    X_zero_point = dequants_inputs[0].args[2]
+    Y_scale = dequants_inputs[1].args[1]
+    Y_zero_point = dequants_inputs[1].args[2]
     args = (
         inputs_inputs[0],
-        X_scale_,
-        X_zero_point_,
+        X_scale,
+        X_zero_point,
         inputs_inputs[1],
-        Y_scale_,
-        Y_zero_point_,
+        Y_scale,
+        Y_zero_point,
         quant_node.args[1],
         quant_node.args[2],
     )
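Every hunk in this commit applies the same substitution: instead of materializing single-element tensors with aten.full graph nodes, the scale and zero point are read straight off the dequantize node's args. A minimal sketch of the arg layout this relies on, assuming the standard decomposed-quantization signature (input, scale, zero_point, quant_min, quant_max, dtype); the helper name is hypothetical:

from typing import Tuple
from torch import fx


def qparams_from_dequant(node: fx.Node) -> Tuple[float, int]:
    # Assumed arg order: (input, scale, zero_point, quant_min, quant_max, dtype).
    # args[1] is the scale and args[2] the zero point; both are plain Python
    # scalars on the node, so they can be embedded in the fused op's args
    # without creating an aten.full node.
    return node.args[1], node.args[2]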
@@ -129,31 +114,12 @@ def get_args_and_kwargs_linear(
     else:
         bias = bias_inputs[0]

-    # Create single element tensors for weight_zero_point, out_multiplier, out_shift.
-    # Note that the function expects int32_t, when it would default to int64_t, so
-    # we explicitly require that type.
-    weight_zero_point_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_weights[0].args[2]),
-        {"dtype": torch.int32},
-    )
-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
     args = tuple(inputs_inputs + weights_inputs + [bias])
     kwargs = {
         "src_zero_point": dequants_inputs[0].args[2],
-        "weight_zero_point": weight_zero_point_,
-        "out_multiplier": out_multiplier_,
-        "out_shift": out_shift_,
+        "weight_zero_point": dequants_weights[0].args[2],
+        "out_multiplier": out_multiplier[0].item(),
+        "out_shift": out_shift[0].item(),
         "out_zero_point": quant_node.args[2],
         "offset": None,
     }
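With the tensor wrappers gone, the kernel receives out_multiplier and out_shift as plain ints via .item(). For context, a hypothetical sketch of the decomposition quantize_tensor_multiplier presumably performs, following the common gemmlowp-style convention (scale ~= multiplier * 2**shift / 2**31, with multiplier an int32 in [2**30, 2**31)); the real helper may differ in rounding details:

import math


def quantize_multiplier_sketch(requantize_scale: float):
    mantissa, exponent = math.frexp(requantize_scale)  # scale = mantissa * 2**exponent
    multiplier = round(mantissa * (1 << 31))           # Q31 fixed point
    if multiplier == (1 << 31):                        # rounding hit 1.0 exactly
        multiplier //= 2
        exponent += 1
    return multiplier, exponent  # (out_multiplier, out_shift)

For linear, requantize_scale would typically be (input_scale * weight_scale) / output_scale, which is what makes the multiplier/shift pair sufficient to map int32 accumulators to the output quantization.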
@@ -178,22 +144,8 @@ def get_args_and_kwargs_layer_norm(
     ), "per-channel quantization is not supported for layer norm, both scale and zero_point should be scalars"

     # Make the scale and zero_point tensors
-    scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[1],
-        ),
-        {"dtype": torch.float32},
-    )
-    zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[2],
-        ),
-        {"dtype": torch.int32},
-    )
+    scale = dequants_inputs[0].args[1]
+    zero_point = dequants_inputs[0].args[2]

     weight = other_inputs[1] if len(other_inputs) > 1 else None

@@ -220,7 +172,7 @@ def get_args_and_kwargs_layer_norm(
     )

     # Make the args and kwargs for the replacement op
-    args = tuple(inputs_inputs + [scale_tensor] + [zero_point_tensor])
+    args = tuple(inputs_inputs + [scale, zero_point])
    kwargs = {
         "normalized_shape": other_inputs[0],
         "weight": weight,
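The net effect of these hunks is that scale and zero_point become constants baked into the node's args rather than separate nodes in the graph. A small self-contained illustration with a toy module (not the fusion pass itself):

import operator

import torch
from torch import fx


class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 0.5  # a Python scalar used directly


gm = fx.symbolic_trace(Toy())
for node in gm.graph.nodes:
    if node.op == "call_function":
        # The scalar appears inline in node.args as 0.5; no aten.full node
        # is created for it.
        print(node.target is operator.mul, node.args)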
@@ -308,31 +260,6 @@ def get_args_and_kwargs_conv(

     (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t)

-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
-    # Create a single element tensor for the weight zero point
-    weight_zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], weight_zero_point),
-        {"dtype": torch.int32},
-    )
-
-    # Create a single element tensor for the bias scale
-    bias_scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], bias_scale),
-        {"dtype": torch.float32},
-    )
-
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs + weights_inputs + [bias])
     kwargs = {
@@ -341,12 +268,12 @@ def get_args_and_kwargs_conv(
         "dilation": dilation,
         "groups": groups,
         "input_zero_point": dequants_inputs[0].args[2],
-        "weight_zero_point": weight_zero_point_tensor,
-        "bias_scale": bias_scale_tensor,
+        "weight_zero_point": weight_zero_point,
+        "bias_scale": bias_scale,
         "out_scale": quant_node.args[1],
         "out_zero_point": quant_node.args[2],
-        "out_multiplier": out_multiplier_,
-        "out_shift": out_shift_,
+        "out_multiplier": out_multiplier[0].item(),
+        "out_shift": out_shift[0].item(),
     }
     return args, kwargs

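For reference, a scalar model of how a kernel could consume out_multiplier and out_shift under the Q31 convention sketched above. This is illustrative only; the Cadence kernel implementations are not part of this diff:

def requantize_sketch(acc: int, multiplier: int, shift: int, out_zero_point: int) -> int:
    # out = round(acc * multiplier * 2**shift / 2**31) + out_zero_point
    total_shift = 31 - shift           # shift is usually <= 0 for scales < 1
    rounding = 1 << (total_shift - 1)  # round-to-nearest
    return ((acc * multiplier + rounding) >> total_shift) + out_zero_point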
@@ -367,27 +294,11 @@ def get_args_and_kwargs_relu(
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs)

-    X_zero_point = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[0].args[2]),
-        {"dtype": torch.int32},
-    )
-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
     kwargs = {
-        "X_zero_point": X_zero_point,
+        "X_zero_point": dequants_inputs[0].args[2],
         "out_zero_point": quant_node.args[2],
-        "out_multiplier": out_multiplier_,
-        "out_shift": out_shift_,
+        "out_multiplier": out_multiplier[0].item(),
+        "out_shift": out_shift[0].item(),
     }
     return args, kwargs

@@ -435,48 +346,20 @@ def get_args_and_kwargs_softmax(
         {"dtype": torch.int32},
     )
     # Make the scale and zero_point tensors
-    in_scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[1],
-        ),
-        {"dtype": torch.float32},
-    )
-    in_zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[2],
-        ),
-        {"dtype": torch.int32},
-    )
-    out_scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            quant_node.args[1],
-        ),
-        {"dtype": torch.float32},
-    )
-    out_zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            quant_node.args[2],
-        ),
-        {"dtype": torch.int32},
-    )
+    in_scale = dequants_inputs[0].args[1]
+    in_zero_point = dequants_inputs[0].args[2]
+    out_scale = quant_node.args[1]
+    out_zero_point = quant_node.args[2]

     # Make the args and kwargs for the replacement op
     args = (
         inputs_inputs[0],
         mask_tensor,
         op_node.args[1],
-        in_scale_tensor,
-        in_zero_point_tensor,
-        out_scale_tensor,
-        out_zero_point_tensor,
+        in_scale,
+        in_zero_point,
+        out_scale,
+        out_zero_point,
     )
     kwargs = {}

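Note that mask_tensor above remains a graph-built aten.full node; only the scalar qparams move into the op's args. Since all six functions repeat the same args[1]/args[2] reads, a shared helper could state the pattern once. A possible follow-up refactor, hypothetical and not part of this commit:

from torch import fx


def scale_and_zero_point(node: fx.Node):
    # Works for both quantize and dequantize nodes: under the decomposed
    # quantize/dequantize arg order, args[1] is the scale and args[2] is
    # the zero point.
    return node.args[1], node.args[2]


# e.g. in get_args_and_kwargs_softmax:
#   in_scale, in_zero_point = scale_and_zero_point(dequants_inputs[0])
#   out_scale, out_zero_point = scale_and_zero_point(quant_node)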