@@ -306,31 +306,6 @@ def get_args_and_kwargs_conv(
306
306
307
307
(out_multiplier , out_shift ) = quantize_tensor_multiplier (requantize_scale_t )
308
308
309
- out_multiplier_ = graph_module .graph .call_function (
310
- torch .ops .aten .full .default ,
311
- ([1 ], out_multiplier [0 ].item ()),
312
- {"dtype" : torch .int32 },
313
- )
314
- out_shift_ = graph_module .graph .call_function (
315
- torch .ops .aten .full .default ,
316
- ([1 ], out_shift [0 ].item ()),
317
- {"dtype" : torch .int32 },
318
- )
319
-
320
- # Create a single element tensor for the weight zero point
321
- weight_zero_point_tensor = graph_module .graph .call_function (
322
- torch .ops .aten .full .default ,
323
- ([1 ], weight_zero_point ),
324
- {"dtype" : torch .int32 },
325
- )
326
-
327
- # Create a single element tensor for the bias scale
328
- bias_scale_tensor = graph_module .graph .call_function (
329
- torch .ops .aten .full .default ,
330
- ([1 ], bias_scale ),
331
- {"dtype" : torch .float32 },
332
- )
333
-
334
309
# Make the args and kwargs for the replacement op
335
310
args = tuple (inputs_inputs + weights_inputs + [bias ])
336
311
kwargs = {
@@ -339,12 +314,12 @@ def get_args_and_kwargs_conv(
339
314
"dilation" : dilation ,
340
315
"groups" : groups ,
341
316
"input_zero_point" : dequants_inputs [0 ].args [2 ],
342
- "weight_zero_point" : weight_zero_point_tensor ,
343
- "bias_scale" : bias_scale_tensor ,
317
+ "weight_zero_point" : weight_zero_point ,
318
+ "bias_scale" : bias_scale ,
344
319
"out_scale" : quant_node .args [1 ],
345
320
"out_zero_point" : quant_node .args [2 ],
346
- "out_multiplier" : out_multiplier_ ,
347
- "out_shift" : out_shift_ ,
321
+ "out_multiplier" : out_multiplier [ 0 ]. item () ,
322
+ "out_shift" : out_shift [ 0 ]. item () ,
348
323
}
349
324
return args , kwargs
350
325
@@ -365,27 +340,11 @@ def get_args_and_kwargs_relu(
365
340
# Make the args and kwargs for the replacement op
366
341
args = tuple (inputs_inputs )
367
342
368
- X_zero_point = graph_module .graph .call_function (
369
- torch .ops .aten .full .default ,
370
- ([1 ], dequants_inputs [0 ].args [2 ]),
371
- {"dtype" : torch .int32 },
372
- )
373
- out_multiplier_ = graph_module .graph .call_function (
374
- torch .ops .aten .full .default ,
375
- ([1 ], out_multiplier [0 ].item ()),
376
- {"dtype" : torch .int32 },
377
- )
378
- out_shift_ = graph_module .graph .call_function (
379
- torch .ops .aten .full .default ,
380
- ([1 ], out_shift [0 ].item ()),
381
- {"dtype" : torch .int32 },
382
- )
383
-
384
343
kwargs = {
385
- "X_zero_point" : X_zero_point ,
344
+ "X_zero_point" : dequants_inputs [ 0 ]. args [ 2 ] ,
386
345
"out_zero_point" : quant_node .args [2 ],
387
- "out_multiplier" : out_multiplier_ ,
388
- "out_shift" : out_shift_ ,
346
+ "out_multiplier" : out_multiplier [ 0 ]. item () ,
347
+ "out_shift" : out_shift [ 0 ]. item () ,
389
348
}
390
349
return args , kwargs
391
350
0 commit comments