@@ -526,34 +526,14 @@ class FuseCascadedViewOps(ExportPass):
526
526
Fuse a cascaded chain of view ops
527
527
"""
528
528
529
- # Find a chain of view ops, and fuse them into a single permute op.
530
-
531
529
def fuse_cascaded_view_ops (self , graph_module : torch .fx .GraphModule ):
532
- graph = graph_module .graph
533
- for node in graph .nodes :
534
- # We are only interested in view ops
535
- if node .target != exir_ops .edge .aten .view_copy .default :
536
- continue
537
-
538
- # Get the cascaded chain of view ops starting at node
539
- cascaded_view_ops = get_cascaded_ops (
540
- [node ], [exir_ops .edge .aten .view_copy .default ]
541
- )
542
- # The chain must have more than 1 node
543
- if len (cascaded_view_ops ) == 1 :
530
+ view_target = exir_ops .edge .aten .view_copy .default
531
+ for view_node in graph_module .graph .find_nodes (op = "call_function" , target = view_target , sort = True ):
532
+ input_view = view_node .args [0 ]
533
+ if input_view .op != "call_function" or input_view .target != view_target :
544
534
continue
545
535
546
- last_view_node = cascaded_view_ops [- 1 ]
547
- with graph .inserting_before (last_view_node ):
548
- new_view = graph .call_function (
549
- exir_ops .edge .aten .view_copy .default ,
550
- args = (node .args [0 ], last_view_node .args [1 ]),
551
- )
552
- last_view_node .replace_all_uses_with (new_view )
553
-
554
- # Now erase the chain
555
- for v in reversed (cascaded_view_ops ):
556
- graph .erase_node (v )
536
+ view_node .replace_input_with (input_view , input_view .args [0 ])
557
537
558
538
graph_module .recompile ()
559
539
0 commit comments