66from copy import copy
77
88import torch
9+ from executorch .backends .arm .tosa_quant_utils import QuantArgs
910from executorch .exir .dialects ._ops import ops as exir_ops
1011from executorch .exir .pass_base import ExportPass
1112
@@ -48,7 +49,40 @@ def _get_decomposition(op):
4849 torch .ops .aten .cat .default ,
4950 )
5051 case _:
51- raise RuntimeError ("Unvalid op for grouped conv decomposition." )
52+ raise RuntimeError ("Invalid op for grouped conv decomposition" )
53+
54+ @staticmethod
55+ def _split_per_channel_qparams (qarg , index , output_slice_size ):
56+ if qarg is not None and qarg .per_channel :
57+ start_index = index * output_slice_size
58+ stop_index = (index + 1 ) * output_slice_size
59+ return QuantArgs (
60+ scale = qarg .scale [start_index :stop_index ],
61+ zp = qarg .zp [start_index :stop_index ],
62+ qmin = qarg .qmin ,
63+ qmax = qarg .qmax ,
64+ dtype = qarg .dtype ,
65+ axis = qarg .axis ,
66+ per_channel = qarg .per_channel ,
67+ )
68+ return qarg
69+
70+ @staticmethod
71+ def _get_meta_copy (meta , i , output_slice_size ):
72+ meta_copy = meta .copy ()
73+ if "input_qparams" in meta .data and len (meta .data ["input_qparams" ]) > 0 :
74+ # Handle per-channel quantization by splitting quantization params
75+ # similarly to how activations/weights/biases are split.
76+ new_qparams = meta .data .get ("input_qparams" ).copy ()
77+ # Get quantization params of the weights and slice them.
78+ qarg = new_qparams [1 ]
79+ new_qparams [1 ] = DecomposeGroupedConv ._split_per_channel_qparams (
80+ qarg , index = i , output_slice_size = output_slice_size
81+ )
82+
83+ meta_copy .data ["input_qparams" ] = new_qparams
84+
85+ return meta_copy
5286
5387 def call_operator (self , op , args , kwargs , meta ):
5488 if op == exir_ops .edge .aten .convolution .default :
@@ -105,7 +139,6 @@ def call_operator(self, op, args, kwargs, meta):
105139 if bias_node is None :
106140 bias_slices .append (None )
107141 else :
108-
109142 start_index = i * output_slice_size
110143 stop_index = (i + 1 ) * output_slice_size
111144 slice_args = (bias_node , 0 , start_index , stop_index )
@@ -115,20 +148,23 @@ def call_operator(self, op, args, kwargs, meta):
115148 )
116149
117150 output_slices = []
118- for input_slice , filter_slice , bias_slice in zip (
119- input_slices , filter_slices , bias_slices
151+ for i , ( input_slice , filter_slice , bias_slice ) in enumerate (
152+ zip ( input_slices , filter_slices , bias_slices )
120153 ):
121154
155+ meta_copy = DecomposeGroupedConv ._get_meta_copy (meta , i , output_slice_size )
156+
122157 if op == exir_ops .edge .aten .convolution .default :
123158 conv_args = (input_slice , filter_slice , bias_slice , * args [3 :8 ], 1 )
124159 elif op == torch .ops .aten .conv2d .default :
125160 conv_args = (input_slice , filter_slice , bias_slice , * args [3 :6 ], 1 )
126161 else :
127- raise RuntimeError ("Unvalid op for grouped conv decomposition. " )
162+ raise RuntimeError ("Invalid op for grouped conv decomposition" )
128163
129164 output_slices .append (
130- super ().call_operator (conv_op , conv_args , kwargs , meta )
165+ super ().call_operator (conv_op , conv_args , kwargs , meta_copy )
131166 )
132167
133168 cat_args = (output_slices , 1 )
134- return super ().call_operator (cat_op , cat_args , kwargs , no_q_dq_meta )
169+ # propagate original metadata (including quantization params) to the concatenated output
170+ return super ().call_operator (cat_op , cat_args , kwargs , meta )
0 commit comments