Skip to content

Commit 9032d77

Browse files
authored
patch qnn silu lowering
Differential Revision: D69636086 Pull Request resolved: pytorch/executorch#8494
1 parent 7964bca commit 9032d77

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

backends/qualcomm/_passes/decompose_silu.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import torch
99
from executorch.exir.pass_base import ExportPass, PassResult
10-
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1110

1211

1312
class DecomposeSilu(ExportPass):
@@ -22,24 +21,23 @@ def _copy_meta(self, meta: Dict):
2221

2322
def call(self, graph_module: torch.fx.GraphModule):
2423
graph = graph_module.graph
25-
partitions = get_source_partitions(
26-
graph, [torch.nn.functional.silu, torch.ops.aten.silu.default]
27-
)
28-
for _, src_partitions in partitions.items():
29-
for src_partition in src_partitions:
30-
31-
inputs = src_partition.input_nodes
32-
silu_node = src_partition.nodes[0]
33-
with graph_module.graph.inserting_after(inputs[0]):
24+
for node in graph.nodes:
25+
if (
26+
node.op == "call_function"
27+
and node.target == torch.ops.aten.silu.default
28+
):
29+
silu_node = node
30+
silu_node_input = node.args[0]
31+
with graph_module.graph.inserting_after(silu_node_input):
3432
sigmoid_node = graph.create_node(
35-
"call_function", torch.ops.aten.sigmoid, (inputs[0],)
33+
"call_function", torch.ops.aten.sigmoid, (silu_node_input,)
3634
)
3735
sigmoid_node.meta = self._copy_meta(silu_node.meta)
3836
with graph_module.graph.inserting_after(sigmoid_node):
3937
mul_node = graph.create_node(
4038
"call_function",
4139
torch.ops.aten.mul,
42-
(inputs[0], sigmoid_node),
40+
(silu_node_input, sigmoid_node),
4341
)
4442
mul_node.meta = self._copy_meta(silu_node.meta)
4543
for user in silu_node.users.copy():

0 commit comments

Comments
 (0)