@@ -471,7 +471,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
471
471
pattern .partition_types (),
472
472
)
473
473
for fused_partition in fused_partitions :
474
- anchors = pattern .get_anchors (graph_module , fused_partition )
474
+ anchors , op_node = pattern .get_anchors (graph_module , fused_partition )
475
475
if not anchors or anchors .empty :
476
476
continue
477
477
if any (self .is_fused (p .nodes ) for p in fused_partition ):
@@ -512,13 +512,10 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
512
512
bias_inputs = [node .args [0 ] for node in dequants_biases ]
513
513
other_inputs = [node .args [idx ] for node , idx in anchors .others ]
514
514
515
- # The node is the first index of the list and first of the tuple
516
- anchor_output_node = anchors . output [ 0 ] [0 ]
515
+ assert op_node is not None , "op_node is None"
516
+ quant_node = list ( op_node . users . keys ()) [0 ]
517
517
518
- assert len (anchor_output_node .users ) == 1
519
- quant_node = list (anchor_output_node .users .keys ())[0 ]
520
-
521
- with graph_module .graph .inserting_after (anchor_output_node ):
518
+ with graph_module .graph .inserting_after (op_node ):
522
519
args = tuple (
523
520
inputs_inputs + weights_inputs + other_inputs + bias_inputs
524
521
)
@@ -532,7 +529,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
532
529
)
533
530
elif isinstance (pattern , CatPattern ):
534
531
args , kwargs = get_args_and_kwargs_cat (
535
- inputs_inputs , other_inputs , anchor_output_node
532
+ inputs_inputs , other_inputs , op_node
536
533
)
537
534
elif isinstance (pattern , ConvReluPatterns ):
538
535
# For ConvReLU, we are fusing Conv+ReLU
@@ -563,7 +560,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
563
560
dequants_weights ,
564
561
bias_inputs ,
565
562
quant_node ,
566
- anchor_output_node ,
563
+ op_node ,
567
564
)
568
565
elif isinstance (pattern , LinearPattern ):
569
566
args , kwargs = get_args_and_kwargs_linear (
@@ -618,20 +615,28 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
618
615
inputs_inputs ,
619
616
dequants_inputs ,
620
617
quant_node ,
621
- anchor_output_node ,
618
+ op_node ,
622
619
)
620
+
623
621
fused = graph_module .graph .call_function (
624
622
pattern .replacement_op (),
625
623
args ,
626
624
kwargs ,
627
625
)
628
- fused .meta = quant_node .meta
629
- quant_node .replace_all_uses_with (fused )
626
+
627
+ if len (anchors .output ) > 0 :
628
+ fused .meta = quant_node .meta
629
+ quant_node .replace_all_uses_with (fused )
630
+ else :
631
+ fused .meta = op_node .meta
632
+ op_node .replace_all_uses_with (fused )
633
+ if op_node .op == "output" :
634
+ _ = graph_module .graph .output ((fused ,))
630
635
631
636
legalize_graph (graph_module )
632
637
graph_module .graph .eliminate_dead_code ()
633
- # pyre-fixme[7]: Incompatible return type
634
638
graph_module .recompile ()
639
+ return PassResult (graph_module , True )
635
640
636
641
@classmethod
637
642
# pyre-ignore[2]: Parameter `nodes` has no type specified
0 commit comments