Skip to content

Commit 011cc87

Browse files
cccclaifacebook-github-bot
authored andcommitted
patch qnn silu lowering (#8494)
Summary: There are some edge lowering case missing. Change the logic to look for the op first because it's not decomposed Differential Revision: D69636086
1 parent bc497a0 commit 011cc87

File tree

1 file changed

+7
-11
lines changed

1 file changed

+7
-11
lines changed

backends/qualcomm/_passes/decompose_silu.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,20 @@ def _copy_meta(self, meta: Dict):
2222

2323
def call(self, graph_module: torch.fx.GraphModule):
2424
graph = graph_module.graph
25-
partitions = get_source_partitions(
26-
graph, [torch.nn.functional.silu, torch.ops.aten.silu.default]
27-
)
28-
for _, src_partitions in partitions.items():
29-
for src_partition in src_partitions:
30-
31-
inputs = src_partition.input_nodes
32-
silu_node = src_partition.nodes[0]
33-
with graph_module.graph.inserting_after(inputs[0]):
25+
for node in graph.nodes:
26+
if node.op == "call_function" and node.target == torch.ops.aten.silu.default:
27+
silu_node = node
28+
silu_node_input = node.args[0]
29+
with graph_module.graph.inserting_after(silu_node_input):
3430
sigmoid_node = graph.create_node(
35-
"call_function", torch.ops.aten.sigmoid, (inputs[0],)
31+
"call_function", torch.ops.aten.sigmoid, (silu_node_input,)
3632
)
3733
sigmoid_node.meta = self._copy_meta(silu_node.meta)
3834
with graph_module.graph.inserting_after(sigmoid_node):
3935
mul_node = graph.create_node(
4036
"call_function",
4137
torch.ops.aten.mul,
42-
(inputs[0], sigmoid_node),
38+
(silu_node_input, sigmoid_node),
4339
)
4440
mul_node.meta = self._copy_meta(silu_node.meta)
4541
for user in silu_node.users.copy():

0 commit comments

Comments
 (0)