@@ -296,7 +296,7 @@ def quantized_layer_norm_per_tensor(
296296 )
297297
298298
299- def quantized_conv (
299+ def quantized_conv_per_tensor (
300300 input_tensor : torch .Tensor ,
301301 weight : torch .Tensor ,
302302 bias : torch .Tensor ,
@@ -305,12 +305,12 @@ def quantized_conv(
305305 dilation : tuple [int , int ],
306306 groups : int ,
307307 in_zero_point : int ,
308- weight_zero_point : torch . Tensor ,
309- bias_scale : torch . Tensor ,
308+ weight_zero_point : int ,
309+ bias_scale : float ,
310310 output_scale : float ,
311311 output_zero_point : int ,
312- out_multiplier : torch . Tensor ,
313- out_shift : torch . Tensor ,
312+ out_multiplier : int ,
313+ out_shift : int ,
314314) -> torch .Tensor :
315315 """
316316 Quantized convolution operation.
@@ -324,19 +324,13 @@ def quantized_conv(
324324 - dilation (Tuple[int]): The dilation of the convolution
325325 - groups (int): The number of groups
326326 - in_zero_point (int): The quantized mapping of zero for the input
327- - weight_zero_point (Tensor ): The quantized mapping of zero for the weight
328- - bias_scale (Tensor ): The quantized bias scale
327+ - weight_zero_point (int ): The quantized mapping of zero for the weight
328+ - bias_scale (float ): The quantized bias scale
329329 - output_scale (float): The scale of the output
330330 - output_zero_point (int): The zero point of the output
331- - out_multiplier (Tensor ): Unused
332- - out_shift (Tensor ): Unused
331+ - out_multiplier (int ): Unused
332+ - out_shift (int ): Unused
333333 """
334- if weight_zero_point .view (- 1 ).shape != (1 ,):
335- raise ValueError ("Weight zero point must be a scalar" )
336-
337- if bias_scale .view (- 1 ).shape != (1 ,):
338- raise ValueError ("Bias scale must be a scalar" )
339-
340334 if len (input_tensor .shape ) == 3 :
341335 float_out = torch .nn .functional .conv1d (
342336 (input_tensor - in_zero_point ).float (),
@@ -371,8 +365,8 @@ def quantized_conv(
371365 )
372366
373367
374- @impl (m , "quantized_conv_nchw " )
375- def quantized_conv_nchw (
368+ @impl (m , "quantized_conv_nchw_per_tensor " )
369+ def quantized_conv_nchw_per_tensor (
376370 input_tensor : torch .Tensor ,
377371 weight : torch .Tensor ,
378372 bias : torch .Tensor ,
@@ -381,12 +375,12 @@ def quantized_conv_nchw(
381375 dilation : tuple [int , int ],
382376 groups : int ,
383377 in_zero_point : int ,
384- weight_zero_point : torch . Tensor ,
385- bias_scale : torch . Tensor ,
378+ weight_zero_point : int ,
379+ bias_scale : float ,
386380 output_scale : float ,
387381 output_zero_point : int ,
388- out_multiplier : torch . Tensor ,
389- out_shift : torch . Tensor ,
382+ out_multiplier : int ,
383+ out_shift : int ,
390384) -> torch .Tensor :
391385 """
392386 Quantized convolution operation.
@@ -400,16 +394,16 @@ def quantized_conv_nchw(
400394 - dilation (Tuple[int]): The dilation of the convolution
401395 - groups (int): The number of groups
402396 - in_zero_point (int): The quantized mapping of zero for the input
403- - weight_zero_point (Tensor ): The quantized mapping of zero for the weight
404- - bias_scale (Tensor ): The quantized bias scale
397+ - weight_zero_point (int ): The quantized mapping of zero for the weight
398+ - bias_scale (float ): The quantized bias scale
405399 - output_scale (float): The scale of the output
406400 - output_zero_point (int): The zero point of the output
407- - out_multiplier (Tensor ): Unused
408- - out_shift (Tensor ): Unused
401+ - out_multiplier (int ): Unused
402+ - out_shift (int ): Unused
409403 """
410404 if not input_tensor .is_contiguous (memory_format = torch .contiguous_format ):
411405 raise ValueError ("Input tensor must be in NCHW format" )
412- return quantized_conv (
406+ return quantized_conv_per_tensor (
413407 input_tensor ,
414408 weight ,
415409 bias ,
@@ -427,8 +421,8 @@ def quantized_conv_nchw(
427421 )
428422
429423
430- @impl (m , "quantized_conv_nhwc " )
431- def quantized_conv_nhwc (
424+ @impl (m , "quantized_conv_nhwc_per_tensor " )
425+ def quantized_conv_nhwc_per_tensor (
432426 input_tensor : torch .Tensor ,
433427 weight : torch .Tensor ,
434428 bias : torch .Tensor ,
@@ -437,12 +431,12 @@ def quantized_conv_nhwc(
437431 dilation : tuple [int , int ],
438432 groups : int ,
439433 in_zero_point : int ,
440- weight_zero_point : torch . Tensor ,
441- bias_scale : torch . Tensor ,
434+ weight_zero_point : int ,
435+ bias_scale : float ,
442436 output_scale : float ,
443437 output_zero_point : int ,
444- out_multiplier : torch . Tensor ,
445- out_shift : torch . Tensor ,
438+ out_multiplier : int ,
439+ out_shift : int ,
446440) -> torch .Tensor :
447441 """
448442 Quantized convolution operation.
@@ -456,18 +450,18 @@ def quantized_conv_nhwc(
456450 - dilation (Tuple[int]): The dilation of the convolution
457451 - groups (int): The number of groups
458452 - in_zero_point (int): The quantized mapping of zero for the input
459- - weight_zero_point (Tensor ): The quantized mapping of zero for the weight
460- - bias_scale (Tensor ): The quantized bias scale
453+ - weight_zero_point (int ): The quantized mapping of zero for the weight
454+ - bias_scale (float ): The quantized bias scale
461455 - output_scale (float): The scale of the output
462456 - output_zero_point (int): The zero point of the output
463- - out_multiplier (Tensor ): Unused
464- - out_shift (Tensor ): Unused
457+ - out_multiplier (int ): Unused
458+ - out_shift (int ): Unused
465459 """
466460
467461 if not input_tensor .is_contiguous (memory_format = torch .channels_last ):
468462 raise ValueError ("Input tensor must be in NHWC format" )
469463
470- return quantized_conv (
464+ return quantized_conv_per_tensor (
471465 input_tensor ,
472466 weight ,
473467 bias ,
[scrape residue removed: "0 commit comments" — GitHub page chrome, not part of the diff]