@@ -298,7 +298,7 @@ def quantized_layer_norm_per_tensor(
     )


-def quantized_conv(
+def quantized_conv_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
@@ -307,12 +307,12 @@ def quantized_conv(
     dilation: tuple[int, int],
     groups: int,
     in_zero_point: int,
-    weight_zero_point: torch.Tensor,
-    bias_scale: torch.Tensor,
+    weight_zero_point: int,
+    bias_scale: float,
     output_scale: float,
     output_zero_point: int,
-    out_multiplier: torch.Tensor,
-    out_shift: torch.Tensor,
+    out_multiplier: int,
+    out_shift: int,
 ) -> torch.Tensor:
     """
     Quantized convolution operation.
@@ -326,19 +326,13 @@ def quantized_conv(
         - dilation (Tuple[int]): The dilation of the convolution
         - groups (int): The number of groups
         - in_zero_point (int): The quantized mapping of zero for the input
-        - weight_zero_point (Tensor): The quantized mapping of zero for the weight
-        - bias_scale (Tensor): The quantized bias scale
+        - weight_zero_point (int): The quantized mapping of zero for the weight
+        - bias_scale (float): The quantized bias scale
         - output_scale (float): The scale of the output
         - output_zero_point (int): The zero point of the output
-        - out_multiplier (Tensor): Unused
-        - out_shift (Tensor): Unused
+        - out_multiplier (int): Unused
+        - out_shift (int): Unused
     """
-    if weight_zero_point.view(-1).shape != (1,):
-        raise ValueError("Weight zero point must be a scalar")
-
-    if bias_scale.view(-1).shape != (1,):
-        raise ValueError("Bias scale must be a scalar")
-
     if len(input_tensor.shape) == 3:
         float_out = torch.nn.functional.conv1d(
             (input_tensor - in_zero_point).float(),
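The hunk above drops the scalar-shape checks (the parameters are now true scalars by type) and keeps the dequantize-then-convolve structure: activations are shifted by `in_zero_point` and cast to float before the standard conv call. For readers who want the end-to-end per-tensor flow in one place, here is a minimal standalone sketch of the 2D case; the helper name, the int8 dtypes, and the assumption that `bias_scale` equals `input_scale * weight_scale` are illustrative and not taken from the patch.

```python
import torch


def quantized_conv2d_per_tensor_sketch(
    input_tensor: torch.Tensor,  # int8 activations, NCHW
    weight: torch.Tensor,        # int8 weights, (out_ch, in_ch, kH, kW)
    bias: torch.Tensor,          # int32 bias, in units of bias_scale
    in_zero_point: int,
    weight_zero_point: int,
    bias_scale: float,
    output_scale: float,
    output_zero_point: int,
) -> torch.Tensor:
    # Dequantize: shift out the per-tensor zero points and accumulate in float.
    acc = torch.nn.functional.conv2d(
        input_tensor.float() - in_zero_point,
        weight.float() - weight_zero_point,
        bias.float(),
    )
    # Assuming bias_scale == input_scale * weight_scale, the accumulator is in
    # units of bias_scale; scale back to real values, then requantize.
    real_out = bias_scale * acc
    q_out = torch.round(real_out / output_scale) + output_zero_point
    return q_out.clamp(-128, 127).to(torch.int8)
```

The `_per_tensor` suffix introduced by this diff refers to exactly this situation: a single zero point and scale for the whole tensor, passed as plain Python scalars, as opposed to a per-channel tensor of quantization parameters.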
@@ -373,8 +367,8 @@ def quantized_conv(
     )


-@impl(m, "quantized_conv_nchw")
-def quantized_conv_nchw(
+@impl(m, "quantized_conv_nchw_per_tensor")
+def quantized_conv_nchw_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
@@ -383,12 +377,12 @@ def quantized_conv_nchw(
     dilation: tuple[int, int],
     groups: int,
     in_zero_point: int,
-    weight_zero_point: torch.Tensor,
-    bias_scale: torch.Tensor,
+    weight_zero_point: int,
+    bias_scale: float,
     output_scale: float,
     output_zero_point: int,
-    out_multiplier: torch.Tensor,
-    out_shift: torch.Tensor,
+    out_multiplier: int,
+    out_shift: int,
 ) -> torch.Tensor:
     """
     Quantized convolution operation.
@@ -402,16 +396,16 @@ def quantized_conv_nchw(
         - dilation (Tuple[int]): The dilation of the convolution
         - groups (int): The number of groups
         - in_zero_point (int): The quantized mapping of zero for the input
-        - weight_zero_point (Tensor): The quantized mapping of zero for the weight
-        - bias_scale (Tensor): The quantized bias scale
+        - weight_zero_point (int): The quantized mapping of zero for the weight
+        - bias_scale (float): The quantized bias scale
         - output_scale (float): The scale of the output
         - output_zero_point (int): The zero point of the output
-        - out_multiplier (Tensor): Unused
-        - out_shift (Tensor): Unused
+        - out_multiplier (int): Unused
+        - out_shift (int): Unused
     """
     if not input_tensor.is_contiguous(memory_format=torch.contiguous_format):
         raise ValueError("Input tensor must be in NCHW format")
-    return quantized_conv(
+    return quantized_conv_per_tensor(
         input_tensor,
         weight,
         bias,
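The two wrappers differ only in their memory-format check: the NCHW variant above requires the default contiguous layout, while the NHWC variant below requires channels-last. Both rely on PyTorch's memory-format flags rather than on tensor shapes, as the illustrative snippet below shows (the shape and dtype are made up):

```python
import torch

x = torch.randint(-128, 127, (1, 3, 8, 8), dtype=torch.int8)
x.is_contiguous(memory_format=torch.contiguous_format)       # True: NCHW layout
x_nhwc = x.contiguous(memory_format=torch.channels_last)     # reorder strides
x_nhwc.is_contiguous(memory_format=torch.channels_last)      # True: NHWC layout
```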
@@ -429,8 +423,8 @@ def quantized_conv_nchw(
     )


-@impl(m, "quantized_conv_nhwc")
-def quantized_conv_nhwc(
+@impl(m, "quantized_conv_nhwc_per_tensor")
+def quantized_conv_nhwc_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
@@ -439,12 +433,12 @@ def quantized_conv_nhwc(
     dilation: tuple[int, int],
     groups: int,
     in_zero_point: int,
-    weight_zero_point: torch.Tensor,
-    bias_scale: torch.Tensor,
+    weight_zero_point: int,
+    bias_scale: float,
     output_scale: float,
     output_zero_point: int,
-    out_multiplier: torch.Tensor,
-    out_shift: torch.Tensor,
+    out_multiplier: int,
+    out_shift: int,
 ) -> torch.Tensor:
     """
     Quantized convolution operation.
@@ -458,18 +452,18 @@ def quantized_conv_nhwc(
         - dilation (Tuple[int]): The dilation of the convolution
         - groups (int): The number of groups
         - in_zero_point (int): The quantized mapping of zero for the input
-        - weight_zero_point (Tensor): The quantized mapping of zero for the weight
-        - bias_scale (Tensor): The quantized bias scale
+        - weight_zero_point (int): The quantized mapping of zero for the weight
+        - bias_scale (float): The quantized bias scale
         - output_scale (float): The scale of the output
         - output_zero_point (int): The zero point of the output
-        - out_multiplier (Tensor): Unused
-        - out_shift (Tensor): Unused
+        - out_multiplier (int): Unused
+        - out_shift (int): Unused
     """

     if not input_tensor.is_contiguous(memory_format=torch.channels_last):
         raise ValueError("Input tensor must be in NHWC format")

-    return quantized_conv(
+    return quantized_conv_per_tensor(
         input_tensor,
         weight,
         bias,
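With the scalar signature in place, callers of the NCHW wrapper pass plain Python numbers for the per-tensor quantization parameters instead of one-element tensors. A hypothetical invocation is sketched below; the shapes and every quantization value are invented, and the `stride`/`padding` arguments, which this diff does not show, are assumed to sit between `bias` and `dilation` as in the surrounding signature.

```python
import torch

# Invented int8 NCHW activations, int8 weights, and int32 bias.
x = torch.randint(-128, 127, (1, 3, 8, 8), dtype=torch.int8)
w = torch.randint(-128, 127, (4, 3, 3, 3), dtype=torch.int8)
b = torch.zeros(4, dtype=torch.int32)

out = quantized_conv_nchw_per_tensor(
    x,
    w,
    b,
    (1, 1),  # stride (assumed position)
    (0, 0),  # padding (assumed position)
    (1, 1),  # dilation
    1,       # groups
    0,       # in_zero_point
    0,       # weight_zero_point, now a plain int
    0.05,    # bias_scale, now a plain float
    0.1,     # output_scale
    0,       # output_zero_point
    1,       # out_multiplier (unused)
    0,       # out_shift (unused)
)
```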