@@ -91,6 +91,16 @@ class OperatorConfig(NamedTuple):
    operators: list[OperatorPatternType]


+def is_relu_node(node: Node) -> bool:
+    """
+    Check if a given node is a relu node
+    """
+    return node.op == "call_function" and node.target in [
+        torch.ops.aten.relu.default,
+        torch.ops.aten.relu_.default,
+    ]
+
+
def _is_annotated(nodes: list[Node]):
    """
    Given a list of nodes (that represents an operator pattern),
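A minimal sketch (not part of this commit) of how the new `is_relu_node` helper behaves on an exported FX graph; the toy module, tensor shape, and the use of `torch.export` here are assumptions for illustration only:

```python
# Sketch only: exercise is_relu_node on a freshly exported graph.
# Assumes a recent PyTorch with torch.export; the Toy module is hypothetical.
import torch
from torch.fx import Node


def is_relu_node(node: Node) -> bool:
    """Check if a given node is a relu node (mirrors the helper added above)."""
    return node.op == "call_function" and node.target in [
        torch.ops.aten.relu.default,
        torch.ops.aten.relu_.default,
    ]


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)


ep = torch.export.export(Toy(), (torch.randn(2, 3),))
print([n.name for n in ep.graph.nodes if is_relu_node(n)])  # expect one relu node
```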
@@ -231,10 +241,7 @@ def _annotate_linear_relu(
    weight_qspec = get_weight_qspec(quantization_config)
    bias_qspec = get_bias_qspec(quantization_config)
    for node in gm.graph.nodes:
-        if node.op != "call_function" or node.target not in [
-            torch.ops.aten.relu.default,
-            torch.ops.aten.relu_.default,
-        ]:
+        if not is_relu_node(node):
            continue
        relu_node = node
        maybe_linear_node = node.args[0]
@@ -285,21 +292,28 @@ def _annotate_linear_relu(
    return annotated_partitions


-@register_annotator("conv")
-def _annotate_conv(
+def _do_annotate_conv(
    gm: torch.fx.GraphModule,
    quantization_config: Optional[QuantizationConfig],
    filter_fn: Optional[Callable[[Node], bool]] = None,
+    is_conv_transpose: bool = False,
) -> Optional[list[list[Node]]]:
    annotated_partitions = []
+    is_conv_node = _is_conv_transpose_node if is_conv_transpose else _is_conv_node
+
    for n in gm.graph.nodes:
-        if n.op != "call_function" or n.target not in [
-            torch.ops.aten.conv1d.default,
-            torch.ops.aten.conv2d.default,
-        ]:
+        if not is_conv_node(n):
            continue
        conv_node = n

+        # This is hacky!
+        # We do not want to annotate the conv node independently when it is part of a
+        # conv + relu pattern, so we skip it if the conv node is consumed by a single relu node.
+        if len(conv_node.users) == 1:
+            user = list(conv_node.users.keys())[0]
+            if is_relu_node(user):
+                continue
+
        input_qspec_map = {}
        input_act = conv_node.args[0]
        assert isinstance(input_act, Node)
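For context, a sketch (also not part of the commit) of the situation the comment above guards against: when a conv feeds directly into a relu, the conv node's only user is that relu, so the standalone conv annotator bails out and leaves the pair to the conv + relu annotator. The toy module is hypothetical, and the exact ATen target a conv lowers to can vary by PyTorch version:

```python
# Sketch only: a conv consumed by a single relu, as seen by the annotator loop.
import torch


class ConvReLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return torch.relu(self.conv(x))


ep = torch.export.export(ConvReLU(), (torch.randn(1, 3, 8, 8),))
for n in ep.graph.nodes:
    if n.op == "call_function" and "conv" in str(n.target):
        users = list(n.users.keys())
        # The conv has exactly one user (the relu), so _do_annotate_conv skips it here.
        print(len(users), users[0].target)
```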
@@ -341,10 +355,7 @@ def _do_annotate_conv_relu(
):
    annotated_partitions = []
    for n in gm.graph.nodes:
-        if n.op != "call_function" or n.target not in [
-            torch.ops.aten.relu.default,
-            torch.ops.aten.relu_.default,
-        ]:
+        if not is_relu_node(n):
            continue
        relu_node = n
        maybe_conv_node = n.args[0]
@@ -393,6 +404,26 @@ def _do_annotate_conv_relu(
    return annotated_partitions


+@register_annotator("conv")
+def _annotate_conv(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    return _do_annotate_conv(
+        gm, quantization_config, filter_fn, is_conv_transpose=False
+    )
+
+
+@register_annotator("conv_transpose")
+def _annotate_transpose_conv(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    return _do_annotate_conv(gm, quantization_config, filter_fn, is_conv_transpose=True)
+
+
@register_annotator("conv_relu")
def _annotate_conv_relu(
    gm: torch.fx.GraphModule,
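The `_is_conv_transpose_node` predicate that `_do_annotate_conv` selects when `is_conv_transpose=True` is not shown in this diff; below is a hypothetical sketch of what such a check typically looks like (the real helper may cover different or additional overloads):

```python
import torch
from torch.fx import Node


def _is_conv_transpose_node(node: Node) -> bool:
    # Hypothetical sketch; not the actual helper from the source file.
    return node.op == "call_function" and node.target in [
        torch.ops.aten.conv_transpose1d.default,
        torch.ops.aten.conv_transpose2d.input,
    ]
```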
@@ -744,10 +775,7 @@ def _annotate_add_relu( # noqa: C901
) -> Optional[list[list[Node]]]:
    annotated_partitions = []
    for node in gm.graph.nodes:
-        if node.op != "call_function" or node.target not in [
-            torch.ops.aten.relu.default,
-            torch.ops.aten.relu_.default,
-        ]:
+        if not is_relu_node(node):
            continue
        relu_node = node
        maybe_add = node.args[0]
@@ -872,10 +900,7 @@ def _annotate_mul_relu( # noqa: C901
) -> Optional[list[list[Node]]]:
    annotated_partitions = []
    for node in gm.graph.nodes:
-        if node.op != "call_function" or node.target not in [
-            torch.ops.aten.relu.default,
-            torch.ops.aten.relu_.default,
-        ]:
+        if not is_relu_node(node):
            continue
        relu_node = node
        maybe_mul = node.args[0]