6
6
from copy import copy
7
7
8
8
import torch
9
+ from executorch .backends .arm .tosa_quant_utils import QuantArgs
9
10
from executorch .exir .dialects ._ops import ops as exir_ops
10
11
from executorch .exir .pass_base import ExportPass
11
12
@@ -48,7 +49,40 @@ def _get_decomposition(op):
48
49
torch .ops .aten .cat .default ,
49
50
)
50
51
case _:
51
- raise RuntimeError ("Unvalid op for grouped conv decomposition." )
52
+ raise RuntimeError ("Invalid op for grouped conv decomposition" )
53
+
54
@staticmethod
def _split_per_channel_qparams(qarg, index, output_slice_size):
    """Return the quantization args that belong to one group's slice.

    Args:
        qarg: Quantization arguments to slice, or None.
        index: Zero-based index of the group being extracted.
        output_slice_size: Number of consecutive scale/zero-point entries
            that make up one group.

    Returns:
        ``qarg`` itself when it is None or not per-channel. Otherwise a new
        ``QuantArgs`` whose ``scale`` and ``zp`` are restricted to entries
        ``[index * output_slice_size, (index + 1) * output_slice_size)``
        and whose remaining fields are copied from ``qarg``.
    """
    # Nothing to slice for missing or per-tensor quantization parameters.
    if qarg is None or not qarg.per_channel:
        return qarg

    group_window = slice(
        index * output_slice_size, (index + 1) * output_slice_size
    )
    return QuantArgs(
        scale=qarg.scale[group_window],
        zp=qarg.zp[group_window],
        qmin=qarg.qmin,
        qmax=qarg.qmax,
        dtype=qarg.dtype,
        axis=qarg.axis,
        per_channel=qarg.per_channel,
    )
69
+
70
@staticmethod
def _get_meta_copy(meta, i, output_slice_size):
    """Copy ``meta`` and narrow its weight quantization params to group ``i``.

    Args:
        meta: Node metadata exposing ``copy()`` and a ``data`` mapping.
        i: Zero-based index of the group the copy is built for.
        output_slice_size: Number of per-channel entries in one group.

    Returns:
        A copy of ``meta``. When ``meta.data`` holds a non-empty
        ``"input_qparams"`` entry, the copy's ``"input_qparams"`` is
        replaced by a copy in which entry 1 (the weights' quantization
        params) has been passed through ``_split_per_channel_qparams``.
    """
    per_group_meta = meta.copy()

    # No input quantization parameters recorded: the plain copy is enough.
    if "input_qparams" not in meta.data or len(meta.data["input_qparams"]) == 0:
        return per_group_meta

    # Per-channel quantization params are split the same way the
    # activations/weights/biases are split; only entry 1 (the weights)
    # is sliced here.
    group_qparams = meta.data["input_qparams"].copy()
    group_qparams[1] = DecomposeGroupedConv._split_per_channel_qparams(
        group_qparams[1], index=i, output_slice_size=output_slice_size
    )
    per_group_meta.data["input_qparams"] = group_qparams

    return per_group_meta
52
86
53
87
def call_operator (self , op , args , kwargs , meta ):
54
88
if op == exir_ops .edge .aten .convolution .default :
@@ -105,7 +139,6 @@ def call_operator(self, op, args, kwargs, meta):
105
139
if bias_node is None :
106
140
bias_slices .append (None )
107
141
else :
108
-
109
142
start_index = i * output_slice_size
110
143
stop_index = (i + 1 ) * output_slice_size
111
144
slice_args = (bias_node , 0 , start_index , stop_index )
@@ -115,20 +148,23 @@ def call_operator(self, op, args, kwargs, meta):
115
148
)
116
149
117
150
output_slices = []
118
- for input_slice , filter_slice , bias_slice in zip (
119
- input_slices , filter_slices , bias_slices
151
+ for i , ( input_slice , filter_slice , bias_slice ) in enumerate (
152
+ zip ( input_slices , filter_slices , bias_slices )
120
153
):
121
154
155
+ meta_copy = DecomposeGroupedConv ._get_meta_copy (meta , i , output_slice_size )
156
+
122
157
if op == exir_ops .edge .aten .convolution .default :
123
158
conv_args = (input_slice , filter_slice , bias_slice , * args [3 :8 ], 1 )
124
159
elif op == torch .ops .aten .conv2d .default :
125
160
conv_args = (input_slice , filter_slice , bias_slice , * args [3 :6 ], 1 )
126
161
else :
127
- raise RuntimeError ("Unvalid op for grouped conv decomposition. " )
162
+ raise RuntimeError ("Invalid op for grouped conv decomposition" )
128
163
129
164
output_slices .append (
130
- super ().call_operator (conv_op , conv_args , kwargs , meta )
165
+ super ().call_operator (conv_op , conv_args , kwargs , meta_copy )
131
166
)
132
167
133
168
cat_args = (output_slices , 1 )
134
- return super ().call_operator (cat_op , cat_args , kwargs , no_q_dq_meta )
169
+ # propagate original metadata (including quantization params) to the concatenated output
170
+ return super ().call_operator (cat_op , cat_args , kwargs , meta )
0 commit comments