85
85
)
86
86
87
87
# Schema registrations for the Cadence quantized convolution ops.
# Layout is encoded in the op name (nhwc = channels-last, nchw = channels-first)
# rather than via a `channel_last` flag argument.
lib.define(
    "quantized_conv_nhwc(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)"
)
lib.define(
    "quantized_conv_nhwc.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_conv_nhwc.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
)
lib.define(
    "quantized_conv_nhwc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_conv_nchw(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)"
)
lib.define(
    "quantized_conv_nchw.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_conv_nchw.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
)
lib.define(
    "quantized_conv_nchw.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_matmul(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)"
)
@@ -532,8 +543,8 @@ def quantized_linear_asym8uxasym8u_asym8u_per_tensor_meta(
532
543
return src .new_empty (out_size , dtype = src .dtype )
533
544
534
545
535
- @register_fake ("cadence::quantized_conv " )
536
- def quantized_conv_meta (
546
+ @register_fake ("cadence::quantized_conv_nhwc " )
547
+ def quantized_conv_nhwc_meta (
537
548
input : torch .Tensor ,
538
549
weight : torch .Tensor ,
539
550
bias : torch .Tensor ,
@@ -548,12 +559,8 @@ def quantized_conv_meta(
548
559
output_zero_point : int ,
549
560
out_multiplier : torch .Tensor ,
550
561
out_shift : torch .Tensor ,
551
- channel_last : bool = False ,
552
562
) -> torch .Tensor :
553
- if channel_last :
554
- out_channels , * kernel_size , _ = weight .shape
555
- else :
556
- out_channels , _ , * kernel_size = weight .shape
563
+ out_channels , * kernel_size , _ = weight .shape
557
564
558
565
in_size = input .shape
559
566
# Assert that the input tensor has at least 3 dimensions, and at most 6
@@ -569,19 +576,63 @@ def quantized_conv_meta(
569
576
padding [1 ],
570
577
dilation [1 ],
571
578
kernel_size [0 ],
572
- channel_last ,
579
+ True ,
573
580
)
574
581
if len (in_size ) == 3
575
582
else get_conv2d_output_size (
576
- in_size , out_channels , stride , padding , dilation , kernel_size , channel_last
583
+ in_size , out_channels , stride , padding , dilation , kernel_size , True
577
584
)
578
585
)
579
586
580
587
return input .new_empty (output_size , dtype = input .dtype )
581
588
582
589
583
- @register_fake ("cadence::quantized_conv.per_tensor" )
584
- def quantized_conv_per_tensor_meta (
590
@register_fake("cadence::quantized_conv_nchw")
def quantized_conv_nchw_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    bias_scale: torch.Tensor,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
) -> torch.Tensor:
    """Fake (meta) kernel for cadence::quantized_conv_nchw.

    Computes only the output shape for the channels-first quantized
    convolution and returns an empty tensor of that shape with the
    input's dtype; no actual convolution is performed.
    """
    # NCHW weight layout: (out_channels, in_channels / groups, *kernel_size).
    out_channels, _, *kernel_size = weight.shape

    in_size = input.shape
    # Assert that the input tensor has at least 3 dimensions, and at most 6
    assert len(in_size) > 2
    assert len(in_size) < 6

    # Compute the output tensor size: 3-D input means 1D conv, otherwise 2D.
    # The final `False` argument tells the size helpers this is not
    # channels-last.
    output_size = (
        get_conv1d_output_size(
            in_size,
            out_channels,
            stride[1],
            padding[1],
            dilation[1],
            kernel_size[0],
            False,
        )
        if len(in_size) == 3
        else get_conv2d_output_size(
            in_size, out_channels, stride, padding, dilation, kernel_size, False
        )
    )

    return input.new_empty(output_size, dtype=input.dtype)
632
+
633
+
634
+ @register_fake ("cadence::quantized_conv_nchw.per_tensor" )
635
+ def quantized_conv_nchw_per_tensor_meta (
585
636
input : torch .Tensor ,
586
637
weight : torch .Tensor ,
587
638
bias : torch .Tensor ,
@@ -596,12 +647,8 @@ def quantized_conv_per_tensor_meta(
596
647
output_zero_point : int ,
597
648
out_multiplier : int ,
598
649
out_shift : int ,
599
- channel_last : bool = False ,
600
650
) -> torch .Tensor :
601
- if channel_last :
602
- out_channels , * kernel_size , _ = weight .shape
603
- else :
604
- out_channels , _ , * kernel_size = weight .shape
651
+ out_channels , _ , * kernel_size = weight .shape
605
652
606
653
in_size = input .shape
607
654
# Assert that the input tensor has at least 3 dimensions, and at most 6
@@ -617,11 +664,55 @@ def quantized_conv_per_tensor_meta(
617
664
padding [1 ],
618
665
dilation [1 ],
619
666
kernel_size [0 ],
620
- channel_last ,
667
+ False ,
621
668
)
622
669
if len (in_size ) == 3
623
670
else get_conv2d_output_size (
624
- in_size , out_channels , stride , padding , dilation , kernel_size , channel_last
671
+ in_size , out_channels , stride , padding , dilation , kernel_size , False
672
+ )
673
+ )
674
+
675
+ return input .new_empty (output_size , dtype = input .dtype )
676
+
677
+
678
+ @register_fake ("cadence::quantized_conv_nhwc.per_tensor" )
679
+ def quantized_conv_nhwc_per_tensor_meta (
680
+ input : torch .Tensor ,
681
+ weight : torch .Tensor ,
682
+ bias : torch .Tensor ,
683
+ stride : Tuple [int ],
684
+ padding : Tuple [int ],
685
+ dilation : Tuple [int ],
686
+ groups : int ,
687
+ in_zero_point : int ,
688
+ weight_zero_point : int ,
689
+ bias_scale : float ,
690
+ output_scale : float ,
691
+ output_zero_point : int ,
692
+ out_multiplier : int ,
693
+ out_shift : int ,
694
+ ) -> torch .Tensor :
695
+ out_channels , * kernel_size , _ = weight .shape
696
+
697
+ in_size = input .shape
698
+ # Assert that the input tensor has at least 3 dimensions, and at most 6
699
+ assert len (in_size ) > 2
700
+ assert len (in_size ) < 6
701
+
702
+ # Compute the output tensor size
703
+ output_size = (
704
+ get_conv1d_output_size (
705
+ in_size ,
706
+ out_channels ,
707
+ stride [1 ],
708
+ padding [1 ],
709
+ dilation [1 ],
710
+ kernel_size [0 ],
711
+ True ,
712
+ )
713
+ if len (in_size ) == 3
714
+ else get_conv2d_output_size (
715
+ in_size , out_channels , stride , padding , dilation , kernel_size , True
625
716
)
626
717
)
627
718
0 commit comments