@@ -471,7 +471,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
471471 pattern .partition_types (),
472472 )
473473 for fused_partition in fused_partitions :
474- anchors = pattern .get_anchors (graph_module , fused_partition )
474+ anchors , op_node = pattern .get_anchors (graph_module , fused_partition )
475475 if not anchors or anchors .empty :
476476 continue
477477 if any (self .is_fused (p .nodes ) for p in fused_partition ):
@@ -512,13 +512,10 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
512512 bias_inputs = [node .args [0 ] for node in dequants_biases ]
513513 other_inputs = [node .args [idx ] for node , idx in anchors .others ]
514514
515- # The node is the first index of the list and first of the tuple
516- anchor_output_node = anchors . output [ 0 ] [0 ]
515+ assert op_node is not None , "op_node is None"
516+ quant_node = list ( op_node . users . keys ()) [0 ]
517517
518- assert len (anchor_output_node .users ) == 1
519- quant_node = list (anchor_output_node .users .keys ())[0 ]
520-
521- with graph_module .graph .inserting_after (anchor_output_node ):
518+ with graph_module .graph .inserting_after (op_node ):
522519 args = tuple (
523520 inputs_inputs + weights_inputs + other_inputs + bias_inputs
524521 )
@@ -532,7 +529,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
532529 )
533530 elif isinstance (pattern , CatPattern ):
534531 args , kwargs = get_args_and_kwargs_cat (
535- inputs_inputs , other_inputs , anchor_output_node
532+ inputs_inputs , other_inputs , op_node
536533 )
537534 elif isinstance (pattern , ConvReluPatterns ):
538535 # For ConvReLU, we are fusing Conv+ReLU
@@ -563,7 +560,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
563560 dequants_weights ,
564561 bias_inputs ,
565562 quant_node ,
566- anchor_output_node ,
563+ op_node ,
567564 )
568565 elif isinstance (pattern , LinearPattern ):
569566 args , kwargs = get_args_and_kwargs_linear (
@@ -618,20 +615,28 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
618615 inputs_inputs ,
619616 dequants_inputs ,
620617 quant_node ,
621- anchor_output_node ,
618+ op_node ,
622619 )
620+
623621 fused = graph_module .graph .call_function (
624622 pattern .replacement_op (),
625623 args ,
626624 kwargs ,
627625 )
628- fused .meta = quant_node .meta
629- quant_node .replace_all_uses_with (fused )
626+
627+ if len (anchors .output ) > 0 :
628+ fused .meta = quant_node .meta
629+ quant_node .replace_all_uses_with (fused )
630+ else :
631+ fused .meta = op_node .meta
632+ op_node .replace_all_uses_with (fused )
633+ if op_node .op == "output" :
634+ _ = graph_module .graph .output ((fused ,))
630635
631636 legalize_graph (graph_module )
632637 graph_module .graph .eliminate_dead_code ()
633- # pyre-fixme[7]: Incompatible return type
634638 graph_module .recompile ()
639+ return PassResult (graph_module , True )
635640
636641 @classmethod
637642 # pyre-ignore[2]: Parameter `nodes` has no type specified
0 commit comments