@@ -4,7 +4,10 @@
 
 import torch
 import torch.nn.functional as F
-from executorch.backends.xnnpack.utils.utils import is_depthwise_conv
+from executorch.backends.xnnpack.utils.utils import (
+    get_groups_from_conv,
+    is_depthwise_conv,
+)
 from torch._subclasses import FakeTensor
 from torch.fx import Node
 from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
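The newly imported `get_groups_from_conv` reads the group count off a convolution node. Its implementation lives in `executorch.backends.xnnpack.utils.utils` and is not part of this diff; purely as a hypothetical sketch of what such a helper could look like (assuming the node is an `aten.convolution.default` call, whose schema places `groups` as the ninth positional argument):

```python
from torch.fx import Node


def get_groups_from_conv_sketch(conv_node: Node) -> int:
    # aten.convolution(input, weight, bias, stride, padding, dilation,
    #                  transposed, output_padding, groups)
    # groups is the last positional argument; default to 1 when absent.
    if len(conv_node.args) > 8:
        return conv_node.args[8]
    return 1
```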
@@ -65,6 +68,28 @@ def decorator(annotator: AnnotatorType) -> None:
     return decorator
 
 
+def change_quantization_config(
+    original_qspec,
+    dtype=None,
+    quant_min=None,
+    quant_max=None,
+    qscheme=None,
+    ch_axis=None,
+    is_dynamic=None,
+    observer_or_fake_quant_ctr=None,
+):
+    return QuantizationSpec(
+        dtype=dtype or original_qspec.dtype,
+        quant_min=quant_min or original_qspec.quant_min,
+        quant_max=quant_max or original_qspec.quant_max,
+        qscheme=qscheme or original_qspec.qscheme,
+        ch_axis=ch_axis or original_qspec.ch_axis,
+        is_dynamic=is_dynamic or original_qspec.is_dynamic,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr
+        or original_qspec.observer_or_fake_quant_ctr,
+    )
+
+
 def is_relu_node(node: Node) -> bool:
     """
     Check if a given node is a relu node
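To make the helper concrete, here is a minimal, hedged usage sketch (the `QuantizationSpec` import path is the one PyTorch 2.x exposes and may differ across versions). One caveat worth noting: because the helper defaults fields with `or`, falsy overrides such as `ch_axis=0` or `is_dynamic=False` silently fall back to the original spec's values; that is harmless in this diff, which only ever passes `ch_axis=1`.

```python
import torch
from torch.ao.quantization.observer import PerChannelMinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec

# A typical per-channel weight qspec for a regular conv (ch_axis=0).
weight_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,
    is_dynamic=False,
    observer_or_fake_quant_ctr=PerChannelMinMaxObserver,
)

# Rebuild the same spec for a transposed conv: only ch_axis changes,
# every other field is inherited from the original qspec.
transposed_qspec = change_quantization_config(weight_qspec, ch_axis=1)
assert transposed_qspec.ch_axis == 1
assert transposed_qspec.dtype == weight_qspec.dtype
```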
@@ -231,31 +256,43 @@ def _do_annotate_conv(
             if is_relu_node(user):
                 continue
 
+        # Tracks whether this conv should be skipped
+        skip = False
+
         input_qspec_map = {}
         input_act = conv_node.args[0]
         assert isinstance(input_act, Node)
         input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
 
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
-        input_qspec_map[weight] = get_weight_qspec(quantization_config)
+        weight_qspec = get_weight_qspec(quantization_config)
+        num_groups = get_groups_from_conv(conv_node)
 
-        # Only annotate dynamically quantized conv if it's 2D and not depthwise
-        if (
+        # Skip transposed convs that have more than one group
+        skip = skip or (is_conv_transpose and num_groups != 1)
+
+        if is_conv_transpose:
+            # transposed convs are quantized per output channel (ch_axis=1)
+            weight_qspec = change_quantization_config(weight_qspec, ch_axis=1)
+
+        input_qspec_map[weight] = weight_qspec
+        is_dynamic = (
             quantization_config
             and quantization_config.input_activation
             and quantization_config.input_activation.is_dynamic
-        ):
+        )
+
+        # Only annotate dynamically quantized conv if it's 2D and not depthwise
+        if is_dynamic:
             weight_val = weight.meta.get("val", None)
             weight_shape = getattr(weight_val, "shape", None)
-
             # Skip if not a 4D weight tensor (i.e. not conv2d)
-            if weight_shape is not None and len(weight_shape) != 4:
-                continue
-
+            skip = skip or (weight_shape is not None and len(weight_shape) != 4)
             # Skip if depthwise (default to groups=1 since it's not an arg)
-            if is_depthwise_conv(weight_shape, 1, is_conv_transpose):
-                continue
+            skip = skip or (
+                not is_conv_transpose and is_depthwise_conv(weight_shape, 1, False)
+            )
 
         # adding weight node to the partition as well
         partition = [conv_node, conv_node.args[1]]
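The `ch_axis=1` override exists because PyTorch lays out transposed-conv weights differently from regular conv weights. A quick illustration (not part of the diff):

```python
import torch

# Conv2d weights are (out_channels, in_channels / groups, kH, kW):
# the output-channel axis is dim 0, hence the default ch_axis=0.
conv = torch.nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3)
print(conv.weight.shape)  # torch.Size([16, 8, 3, 3])

# ConvTranspose2d weights are (in_channels, out_channels / groups, kH, kW):
# the output-channel axis moves to dim 1, hence ch_axis=1 above.
deconv = torch.nn.ConvTranspose2d(in_channels=8, out_channels=16, kernel_size=3)
print(deconv.weight.shape)  # torch.Size([8, 16, 3, 3])
```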
@@ -265,7 +302,7 @@ def _do_annotate_conv(
             input_qspec_map[bias] = get_bias_qspec(quantization_config)
             partition.append(bias)
 
-        if _is_annotated(partition):
+        if _is_annotated(partition) or skip:
             continue
 
         if filter_fn and any(not filter_fn(n) for n in partition):
@@ -311,7 +348,12 @@ def _do_annotate_conv_relu(
 
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
-        input_qspec_map[weight] = get_weight_qspec(quantization_config)
+        weight_qspec = get_weight_qspec(quantization_config)
+        groups = get_groups_from_conv(conv_node)
+        if is_conv_transpose:
+            # transposed convs are quantized per output channel (ch_axis=1)
+            weight_qspec = change_quantization_config(weight_qspec, ch_axis=1)
+        input_qspec_map[weight] = weight_qspec
 
         # adding weight node to the partition as well
         partition = [relu_node, conv_node, conv_node.args[1]]
@@ -323,6 +365,9 @@ def _do_annotate_conv_relu(
         if _is_annotated(partition):
            continue
 
+        if is_conv_transpose and groups != 1:
+            continue
+
         if filter_fn and any(not filter_fn(n) for n in partition):
             continue
 
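Both annotation paths skip transposed convs with `groups != 1` for the same reason: with grouping, dim 1 of a transposed-conv weight only spans `out_channels // groups`, so one scale per dim-1 slice can no longer map onto every output channel. A small illustration (not part of the diff):

```python
import torch

# groups=2 splits the 16 output channels into two groups of 8;
# dim 1 of the weight is 16 // 2 = 8, not 16, so per-channel scales
# taken along dim 1 cannot cover all 16 output channels.
grouped = torch.nn.ConvTranspose2d(
    in_channels=8, out_channels=16, kernel_size=3, groups=2
)
print(grouped.weight.shape)  # torch.Size([8, 8, 3, 3])
```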