8585)
8686
# Operator schema registrations for the layout-specific quantized convolution
# variants (NHWC / channel-last and NCHW / channel-first), each with a
# functional form, an out= form, and per-tensor-quantized counterparts.
lib.define(
    "quantized_conv_nhwc(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)"
)
lib.define(
    "quantized_conv_nhwc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_conv_nhwc.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
)
lib.define(
    "quantized_conv_nhwc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_conv_nchw(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)"
)
lib.define(
    "quantized_conv_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_conv_nchw.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
)
lib.define(
    "quantized_conv_nchw.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_matmul(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)"
)
@@ -532,8 +543,8 @@ def quantized_linear_asym8uxasym8u_asym8u_per_tensor_meta(
532543 return src .new_empty (out_size , dtype = src .dtype )
533544
534545
@register_fake("cadence::quantized_conv_nhwc")
def quantized_conv_nhwc_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    bias_scale: torch.Tensor,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
) -> torch.Tensor:
    """Shape-inference (fake) kernel for ``cadence::quantized_conv_nhwc``.

    No convolution is performed: only the output geometry is derived and an
    empty tensor of that size (with the input's dtype) is returned.
    """
    # Channel-last weight layout: (out_channels, *kernel_size, in_channels_per_group).
    out_channels, *kernel_size, _ = weight.shape

    in_size = input.shape
    # Accept rank 3 (treated as 1-d conv) up to rank 5.
    assert len(in_size) > 2
    assert len(in_size) < 6

    # Derive the output size; the trailing True selects channel-last layout
    # in the helper.
    if len(in_size) == 3:
        # 1-d convolution path: the spatial stride/padding/dilation live at
        # index 1 of the parameter lists.
        output_size = get_conv1d_output_size(
            in_size,
            out_channels,
            stride[1],
            padding[1],
            dilation[1],
            kernel_size[0],
            True,
        )
    else:
        output_size = get_conv2d_output_size(
            in_size, out_channels, stride, padding, dilation, kernel_size, True
        )

    return input.new_empty(output_size, dtype=input.dtype)
581588
582589
@register_fake("cadence::quantized_conv_nchw")
def quantized_conv_nchw_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    bias_scale: torch.Tensor,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
) -> torch.Tensor:
    """Shape-inference (fake) kernel for ``cadence::quantized_conv_nchw``.

    No convolution is performed: only the output geometry is derived and an
    empty tensor of that size (with the input's dtype) is returned.
    """
    # Channel-first weight layout: (out_channels, in_channels_per_group, *kernel_size).
    out_channels, _, *kernel_size = weight.shape

    in_size = input.shape
    # Accept rank 3 (treated as 1-d conv) up to rank 5.
    assert len(in_size) > 2
    assert len(in_size) < 6

    # Derive the output size; the trailing False selects channel-first layout
    # in the helper.
    if len(in_size) == 3:
        # 1-d convolution path: the spatial stride/padding/dilation live at
        # index 1 of the parameter lists.
        output_size = get_conv1d_output_size(
            in_size,
            out_channels,
            stride[1],
            padding[1],
            dilation[1],
            kernel_size[0],
            False,
        )
    else:
        output_size = get_conv2d_output_size(
            in_size, out_channels, stride, padding, dilation, kernel_size, False
        )

    return input.new_empty(output_size, dtype=input.dtype)
632+
633+
@register_fake("cadence::quantized_conv_nchw.per_tensor")
def quantized_conv_nchw_per_tensor_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: int,
    bias_scale: float,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: int,
    out_shift: int,
) -> torch.Tensor:
    """Shape-inference (fake) kernel for ``cadence::quantized_conv_nchw.per_tensor``.

    Same geometry computation as the tensor-quantized variant; the per-tensor
    quantization scalars do not affect the output shape. Returns an empty
    tensor of the inferred size with the input's dtype.
    """
    # Channel-first weight layout: (out_channels, in_channels_per_group, *kernel_size).
    out_channels, _, *kernel_size = weight.shape

    in_size = input.shape
    # Accept rank 3 (treated as 1-d conv) up to rank 5.
    assert len(in_size) > 2
    assert len(in_size) < 6

    # Derive the output size; the trailing False selects channel-first layout
    # in the helper.
    if len(in_size) == 3:
        # 1-d convolution path: the spatial stride/padding/dilation live at
        # index 1 of the parameter lists.
        output_size = get_conv1d_output_size(
            in_size,
            out_channels,
            stride[1],
            padding[1],
            dilation[1],
            kernel_size[0],
            False,
        )
    else:
        output_size = get_conv2d_output_size(
            in_size, out_channels, stride, padding, dilation, kernel_size, False
        )

    return input.new_empty(output_size, dtype=input.dtype)
676+
677+
678+ @register_fake ("cadence::quantized_conv_nhwc.per_tensor" )
679+ def quantized_conv_nhwc_per_tensor_meta (
680+ input : torch .Tensor ,
681+ weight : torch .Tensor ,
682+ bias : torch .Tensor ,
683+ stride : Tuple [int ],
684+ padding : Tuple [int ],
685+ dilation : Tuple [int ],
686+ groups : int ,
687+ in_zero_point : int ,
688+ weight_zero_point : int ,
689+ bias_scale : float ,
690+ output_scale : float ,
691+ output_zero_point : int ,
692+ out_multiplier : int ,
693+ out_shift : int ,
694+ ) -> torch .Tensor :
695+ out_channels , * kernel_size , _ = weight .shape
696+
697+ in_size = input .shape
698+ # Assert that the input tensor has at least 3 dimensions, and at most 6
699+ assert len (in_size ) > 2
700+ assert len (in_size ) < 6
701+
702+ # Compute the output tensor size
703+ output_size = (
704+ get_conv1d_output_size (
705+ in_size ,
706+ out_channels ,
707+ stride [1 ],
708+ padding [1 ],
709+ dilation [1 ],
710+ kernel_size [0 ],
711+ True ,
712+ )
713+ if len (in_size ) == 3
714+ else get_conv2d_output_size (
715+ in_size , out_channels , stride , padding , dilation , kernel_size , True
625716 )
626717 )
627718
0 commit comments