77
88import torch
99from executorch .exir .pass_base import ExportPass , PassResult
10- from torch .fx .passes .utils .source_matcher_utils import get_source_partitions
1110
1211
1312class DecomposeSilu (ExportPass ):
@@ -22,24 +21,23 @@ def _copy_meta(self, meta: Dict):
2221
2322 def call (self , graph_module : torch .fx .GraphModule ):
2423 graph = graph_module .graph
25- partitions = get_source_partitions (
26- graph , [torch .nn .functional .silu , torch .ops .aten .silu .default ]
27- )
28- for _ , src_partitions in partitions .items ():
29- for src_partition in src_partitions :
30-
31- inputs = src_partition .input_nodes
32- silu_node = src_partition .nodes [0 ]
33- with graph_module .graph .inserting_after (inputs [0 ]):
24+ for node in graph .nodes :
25+ if (
26+ node .op == "call_function"
27+ and node .target == torch .ops .aten .silu .default
28+ ):
29+ silu_node = node
30+ silu_node_input = node .args [0 ]
31+ with graph_module .graph .inserting_after (silu_node_input ):
3432 sigmoid_node = graph .create_node (
35- "call_function" , torch .ops .aten .sigmoid , (inputs [ 0 ] ,)
33+ "call_function" , torch .ops .aten .sigmoid , (silu_node_input ,)
3634 )
3735 sigmoid_node .meta = self ._copy_meta (silu_node .meta )
3836 with graph_module .graph .inserting_after (sigmoid_node ):
3937 mul_node = graph .create_node (
4038 "call_function" ,
4139 torch .ops .aten .mul ,
42- (inputs [ 0 ] , sigmoid_node ),
40+ (silu_node_input , sigmoid_node ),
4341 )
4442 mul_node .meta = self ._copy_meta (silu_node .meta )
4543 for user in silu_node .users .copy ():
0 commit comments