 
 import torch
 import torch.nn.functional as F
-from executorch.backends.xnnpack.utils.utils import is_depthwise_conv
+from executorch.backends.xnnpack.utils.utils import (
+    get_groups_from_conv,
+    is_depthwise_conv,
+)
 from torch._subclasses import FakeTensor
 from torch.fx import Node
 from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
@@ -65,6 +68,28 @@ def decorator(annotator: AnnotatorType) -> None:
     return decorator
 
 
+def change_quantization_config(
+    original_qspec,
+    dtype=None,
+    quant_min=None,
+    quant_max=None,
+    qscheme=None,
+    ch_axis=None,
+    is_dynamic=None,
+    observer_or_fake_quant_ctr=None,
+):
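+    """Return a copy of ``original_qspec`` with any given fields overridden.
+
+    Note: the ``or`` fallbacks below mean that falsy overrides (e.g.
+    ``ch_axis=0`` or ``is_dynamic=False``) fall back to the original value;
+    the call sites in this file only pass the truthy override ``ch_axis=1``.
+    """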
+    return QuantizationSpec(
+        dtype=dtype or original_qspec.dtype,
+        quant_min=quant_min or original_qspec.quant_min,
+        quant_max=quant_max or original_qspec.quant_max,
+        qscheme=qscheme or original_qspec.qscheme,
+        ch_axis=ch_axis or original_qspec.ch_axis,
+        is_dynamic=is_dynamic or original_qspec.is_dynamic,
+        observer_or_fake_quant_ctr=observer_or_fake_quant_ctr
+        or original_qspec.observer_or_fake_quant_ctr,
+    )
+
+
 def is_relu_node(node: Node) -> bool:
     """
     Check if a given node is a relu node
@@ -231,31 +256,44 @@ def _do_annotate_conv(
         if is_relu_node(user):
             continue
 
+        # Accumulates conditions under which this conv should be skipped
+        skip = False
+
         input_qspec_map = {}
         input_act = conv_node.args[0]
         assert isinstance(input_act, Node)
         input_qspec_map[input_act] = get_input_act_qspec(quantization_config)
 
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
-        input_qspec_map[weight] = get_weight_qspec(quantization_config)
+        weight_qspec = get_weight_qspec(quantization_config)
+        num_groups = get_groups_from_conv(conv_node)
 
-        # Only annotate dynamically quantized conv if it's 2D and not depthwise
-        if (
+        # Skip if a transposed conv has more than one group
+        skip = skip or (is_conv_transpose and num_groups != 1)
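+        # (Presumably grouped transposed convs are skipped because the
+        # per-channel ch_axis=1 scheme applied below does not hold cleanly
+        # when groups > 1; the rationale is not stated in this diff.)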
+
+        if is_conv_transpose:
+            # Transposed conv weights are laid out with output channels on
+            # dim 1, so per-channel quantization must use ch_axis=1
+            weight_qspec = change_quantization_config(weight_qspec, ch_axis=1)
+
+        input_qspec_map[weight] = weight_qspec
+        is_dynamic = (
             quantization_config
             and quantization_config.input_activation
             and quantization_config.input_activation.is_dynamic
-        ):
+        )
+
+        # Only annotate dynamically quantized conv if it's 2D and not depthwise
+        if is_dynamic:
             weight_val = weight.meta.get("val", None)
             weight_shape = getattr(weight_val, "shape", None)
-
             # Skip if not a 4D weight tensor (i.e. not conv2d)
-            if weight_shape is not None and len(weight_shape) != 4:
-                continue
-
+            skip = skip or (weight_shape is not None and len(weight_shape) != 4)
             # Skip if depthwise (default to groups=1 since it's not an arg)
-            if is_depthwise_conv(weight_shape, 1, is_conv_transpose):
-                continue
+            skip = skip or (
+                not is_conv_transpose and is_depthwise_conv(weight_shape, 1, False)
+            )
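+
+        # Design note: skip conditions accumulate in `skip` rather than
+        # `continue`-ing early, so the single `_is_annotated(partition) or skip`
+        # check below handles all of them once the partition is fully built.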
 
         # adding weight node to the partition as well
         partition = [conv_node, conv_node.args[1]]
@@ -265,7 +303,7 @@ def _do_annotate_conv(
             input_qspec_map[bias] = get_bias_qspec(quantization_config)
             partition.append(bias)
 
-        if _is_annotated(partition):
+        if _is_annotated(partition) or skip:
             continue
 
         if filter_fn and any(not filter_fn(n) for n in partition):
@@ -311,7 +349,12 @@ def _do_annotate_conv_relu(
 
         weight = conv_node.args[1]
         assert isinstance(weight, Node)
-        input_qspec_map[weight] = get_weight_qspec(quantization_config)
+        weight_qspec = get_weight_qspec(quantization_config)
+        groups = get_groups_from_conv(conv_node)
+        if is_conv_transpose:
+            # Transposed conv weights are laid out with output channels on
+            # dim 1, so per-channel quantization must use ch_axis=1
+            weight_qspec = change_quantization_config(weight_qspec, ch_axis=1)
+        input_qspec_map[weight] = weight_qspec
 
         # adding weight node to the partition as well
         partition = [relu_node, conv_node, conv_node.args[1]]
@@ -323,6 +366,9 @@ def _do_annotate_conv_relu(
         if _is_annotated(partition):
             continue
 
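+        # Mirrors the check in _do_annotate_conv: grouped transposed convs
+        # are not annotated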
+        if is_conv_transpose and groups != 1:
+            continue
+
         if filter_fn and any(not filter_fn(n) for n in partition):
             continue
 