Commit b914e66

Qualcomm AI Engine Direct - Replace get_source_partition in FixedLinearKeepDim (#13213)
- Enumerate all nodes in the graph to find linear nodes instead of using get_source_partitions (a sketch of this kind of lookup follows below)
1 parent 9b0feb0 commit b914e66
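
For context (not part of the commit), the sketch below illustrates the kind of direct node enumeration described above: export a toy module and collect every call_function node whose target is torch.ops.aten.linear.default, with no source-partition metadata involved. The module TinyLinear and the helper find_linear_nodes are invented for this illustration, and depending on the PyTorch version and the decompositions applied, linear may not survive as aten.linear.default in the exported graph.

import torch
from torch.export import export


class TinyLinear(torch.nn.Module):
    """Toy module used only for this sketch."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.fc(x)


def find_linear_nodes(graph_module: torch.fx.GraphModule):
    # Direct enumeration over the FX graph: keep every call_function node
    # whose target is aten.linear.default, without consulting the
    # source-partition metadata that get_source_partitions relies on.
    return [
        n
        for n in graph_module.graph.nodes
        if n.op == "call_function" and n.target == torch.ops.aten.linear.default
    ]


ep = export(TinyLinear(), (torch.randn(2, 3, 8),))
print(find_linear_nodes(ep.graph_module))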

1 file changed: backends/qualcomm/_passes/fixed_linear_keep_dim.py (51 additions, 56 deletions)
@@ -9,8 +9,6 @@
 from executorch.exir.pass_base import ExportPass, PassResult
 from executorch.exir.passes import dead_code_elimination_pass

-from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
-

 class FixedLinearKeepDim(ExportPass):
     """
@@ -24,61 +22,58 @@ def __init__(self):
         super(FixedLinearKeepDim, self).__init__()

     def _fixed_keep_dim(self, graph_module: torch.fx.GraphModule):
-        partitions = get_source_partitions(
-            graph_module.graph, [torch.nn.Linear, torch.ops.aten.linear.default]
-        )
-        for _, src_partitions in partitions.items():
-            for src_partition in src_partitions:
-                linear_node = [
-                    n for n in src_partition.nodes if n.target == self.linear
-                ][0]
-                input_node = linear_node.args[0]
-                # Since QNN has no keep dims for linear op, we will need to add squeeze and unsqueeze around linear node
-                # TODO: Find a more general conditional statement.
-                linear_output = linear_node.meta["val"]
-                if linear_output.dim() >= 3:
-                    with graph_module.graph.inserting_after(input_node):
-                        input_users = list(input_node.users.keys())
-                        input_tensor = input_node.meta["val"]
-                        squeeze_dim = (-1, input_tensor.shape[-1])
-                        squeeze_node = graph_module.graph.create_node(
-                            "call_function",
-                            self.view_copy,
-                            (
-                                input_node,
-                                squeeze_dim,
-                            ),
-                        )
-                        # meta needs to be copied elementwisely for fake-tensor
-                        # to be updated correctly and not affect meta of input_node
-                        for k, v in input_node.meta.items():
-                            squeeze_node.meta[k] = v
-                        squeeze_node.meta["val"] = input_tensor.reshape(squeeze_dim)
-                        for user in input_users:
-                            if user == linear_node:
-                                user.replace_input_with(input_node, squeeze_node)
+        for node in graph_module.graph.nodes:
+            if node.target != self.linear:
+                continue
+
+            linear_node = node
+            input_node = linear_node.args[0]
+            # Since QNN has no keep dims for linear op, we will need to add squeeze and unsqueeze around linear node
+            # TODO: Find a more general conditional statement.
+            linear_output = linear_node.meta["val"]
+            if linear_output.dim() >= 3:
+                with graph_module.graph.inserting_after(input_node):
+                    input_users = list(input_node.users.keys())
+                    input_tensor = input_node.meta["val"]
+                    squeeze_dim = (-1, input_tensor.shape[-1])
+                    squeeze_node = graph_module.graph.create_node(
+                        "call_function",
+                        self.view_copy,
+                        (
+                            input_node,
+                            squeeze_dim,
+                        ),
+                    )
+                    # meta needs to be copied elementwisely for fake-tensor
+                    # to be updated correctly and not affect meta of input_node
+                    for k, v in input_node.meta.items():
+                        squeeze_node.meta[k] = v
+                    squeeze_node.meta["val"] = input_tensor.reshape(squeeze_dim)
+                    for user in input_users:
+                        if user == linear_node:
+                            user.replace_input_with(input_node, squeeze_node)

-                    with graph_module.graph.inserting_after(linear_node):
-                        output_users = list(linear_node.users.keys())
-                        unsqueeze_dim = linear_output.shape
-                        unsqueeze_node = graph_module.graph.create_node(
-                            "call_function",
-                            self.view_copy,
-                            (
-                                linear_node,
-                                unsqueeze_dim,
-                            ),
-                        )
-                        # meta needs to be copied elementwisely for fake-tensor
-                        # to be updated correctly and not affect meta of unsqueeze_node
-                        for k, v in linear_node.meta.items():
-                            unsqueeze_node.meta[k] = v
-                        # update linear node's shape
-                        linear_node.meta["val"] = linear_output.reshape(
-                            (squeeze_node.meta["val"].shape[0], linear_output.shape[-1])
-                        )
-                        for user in output_users:
-                            user.replace_input_with(linear_node, unsqueeze_node)
+                with graph_module.graph.inserting_after(linear_node):
+                    output_users = list(linear_node.users.keys())
+                    unsqueeze_dim = linear_output.shape
+                    unsqueeze_node = graph_module.graph.create_node(
+                        "call_function",
+                        self.view_copy,
+                        (
+                            linear_node,
+                            unsqueeze_dim,
+                        ),
+                    )
+                    # meta needs to be copied elementwisely for fake-tensor
+                    # to be updated correctly and not affect meta of unsqueeze_node
+                    for k, v in linear_node.meta.items():
+                        unsqueeze_node.meta[k] = v
+                    # update linear node's shape
+                    linear_node.meta["val"] = linear_output.reshape(
+                        (squeeze_node.meta["val"].shape[0], linear_output.shape[-1])
+                    )
+                    for user in output_users:
+                        user.replace_input_with(linear_node, unsqueeze_node)

     def call(self, graph_module: torch.fx.GraphModule):
         self._fixed_keep_dim(graph_module)
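As a side note on what the view_copy nodes in the diff encode: QNN's fully connected op does not keep the input rank, so the pass flattens a rank >= 3 linear input to 2-D and restores the original output shape afterwards. The plain-tensor sketch below (not from the repository; shapes are made up) mirrors that shape bookkeeping.

import torch

# Rank-3 input (batch, sequence, in_features); shapes chosen arbitrarily.
x = torch.randn(2, 5, 8)
linear = torch.nn.Linear(8, 4)

expected = linear(x)  # what the original graph computes

# "Squeeze": collapse the leading dims so the linear sees a 2-D tensor -> (10, 8).
squeezed = x.reshape(-1, x.shape[-1])
flat_out = linear(squeezed)  # 2-D linear -> (10, 4)

# "Unsqueeze": reshape back to the original output shape -> (2, 5, 4).
restored = flat_out.reshape(expected.shape)

assert torch.allclose(expected, restored)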