@@ -298,8 +298,7 @@ def quantized_layer_norm_per_tensor(
298298 )
299299
300300
301- @impl (m , "quantized_conv_nchw" )
302- def quantized_conv_nchw (
301+ def quantized_conv (
303302 input_tensor : torch .Tensor ,
304303 weight : torch .Tensor ,
305304 bias : torch .Tensor ,
@@ -374,6 +373,120 @@ def quantized_conv_nchw(
374373 )
375374
376375
376+ @impl (m , "quantized_conv_nchw" )
377+ def quantized_conv_nchw (
378+ input_tensor : torch .Tensor ,
379+ weight : torch .Tensor ,
380+ bias : torch .Tensor ,
381+ stride : tuple [int , int ],
382+ padding : tuple [int , int ],
383+ dilation : tuple [int , int ],
384+ groups : int ,
385+ in_zero_point : int ,
386+ weight_zero_point : torch .Tensor ,
387+ bias_scale : torch .Tensor ,
388+ output_scale : float ,
389+ output_zero_point : int ,
390+ out_multiplier : torch .Tensor ,
391+ out_shift : torch .Tensor ,
392+ ) -> torch .Tensor :
393+ """
394+ Quantized convolution operation.
395+
396+ Args:
397+ - input_tensor (Tensor): The activations tensor
398+ - weight (Tensor): The weight tensor
399+ - bias (Tensor): The bias tensor
400+ - stride (Tuple[int]): The stride of the convolution
401+ - padding (Tuple[int]): The padding of the convolution
402+ - dilation (Tuple[int]): The dilation of the convolution
403+ - groups (int): The number of groups
404+ - in_zero_point (int): The quantized mapping of zero for the input
405+ - weight_zero_point (Tensor): The quantized mapping of zero for the weight
406+ - bias_scale (Tensor): The quantized bias scale
407+ - output_scale (float): The scale of the output
408+ - output_zero_point (int): The zero point of the output
409+ - out_multiplier (Tensor): Unused
410+ - out_shift (Tensor): Unused
411+ """
412+ if not input_tensor .is_contiguous (memory_format = torch .contiguous_format ):
413+ raise ValueError ("Input tensor must be in NCHW format" )
414+ return quantized_conv (
415+ input_tensor ,
416+ weight ,
417+ bias ,
418+ stride ,
419+ padding ,
420+ dilation ,
421+ groups ,
422+ in_zero_point ,
423+ weight_zero_point ,
424+ bias_scale ,
425+ output_scale ,
426+ output_zero_point ,
427+ out_multiplier ,
428+ out_shift ,
429+ )
430+
431+
432+ @impl (m , "quantized_conv_nhwc" )
433+ def quantized_conv_nhwc (
434+ input_tensor : torch .Tensor ,
435+ weight : torch .Tensor ,
436+ bias : torch .Tensor ,
437+ stride : tuple [int , int ],
438+ padding : tuple [int , int ],
439+ dilation : tuple [int , int ],
440+ groups : int ,
441+ in_zero_point : int ,
442+ weight_zero_point : torch .Tensor ,
443+ bias_scale : torch .Tensor ,
444+ output_scale : float ,
445+ output_zero_point : int ,
446+ out_multiplier : torch .Tensor ,
447+ out_shift : torch .Tensor ,
448+ ) -> torch .Tensor :
449+ """
450+ Quantized convolution operation.
451+
452+ Args:
453+ - input_tensor (Tensor): The activations tensor
454+ - weight (Tensor): The weight tensor
455+ - bias (Tensor): The bias tensor
456+ - stride (Tuple[int]): The stride of the convolution
457+ - padding (Tuple[int]): The padding of the convolution
458+ - dilation (Tuple[int]): The dilation of the convolution
459+ - groups (int): The number of groups
460+ - in_zero_point (int): The quantized mapping of zero for the input
461+ - weight_zero_point (Tensor): The quantized mapping of zero for the weight
462+ - bias_scale (Tensor): The quantized bias scale
463+ - output_scale (float): The scale of the output
464+ - output_zero_point (int): The zero point of the output
465+ - out_multiplier (Tensor): Unused
466+ - out_shift (Tensor): Unused
467+ """
468+
469+ if not input_tensor .is_contiguous (memory_format = torch .channels_last ):
470+ raise ValueError ("Input tensor must be in NHWC format" )
471+
472+ return quantized_conv (
473+ input_tensor ,
474+ weight ,
475+ bias ,
476+ stride ,
477+ padding ,
478+ dilation ,
479+ groups ,
480+ in_zero_point ,
481+ weight_zero_point ,
482+ bias_scale ,
483+ output_scale ,
484+ output_zero_point ,
485+ out_multiplier ,
486+ out_shift ,
487+ )
488+
489+
377490@impl (m , "requantize" )
378491def requantize (
379492 input : torch .Tensor ,
0 commit comments