Skip to content

Commit 275adee

Browse files
Arm backend: Use correct name when indexing state_dict (#12954)
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 9f8c2f6 commit 275adee

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import torch
99
from executorch.exir.pass_base import ExportPass, PassResult
10+
from torch._export.utils import is_buffer, is_param
1011

1112

1213
class UnsqueezeScalarPlaceholdersPass(ExportPass):
@@ -19,23 +20,27 @@ def __init__(self, exported_program):
1920
self.exported_program = exported_program
2021
super().__init__()
2122

22-
def _is_inputs_to_buffers_or_parameters(self, node):
23-
return (
24-
node.name in self.exported_program.graph_signature.inputs_to_buffers
25-
or node.name in self.exported_program.graph_signature.inputs_to_parameters
26-
)
27-
2823
def call(self, graph_module: torch.fx.GraphModule):
2924
for node in graph_module.graph.nodes:
3025
if node.op != "placeholder":
3126
continue
3227
rank = node.meta["val"].dim()
3328
if rank == 0:
34-
if not self._is_inputs_to_buffers_or_parameters(node):
29+
if is_buffer(self.exported_program, node):
30+
name = self.exported_program.graph_signature.inputs_to_buffers[
31+
node.name
32+
]
33+
elif is_param(self.exported_program, node):
34+
name = self.exported_program.graph_signature.inputs_to_parameters[
35+
node.name
36+
]
37+
else:
3538
continue
36-
tensor = self.exported_program.state_dict[node.name]
39+
40+
tensor = self.exported_program.state_dict[name]
41+
3742
if tensor.dim() == 0:
38-
self.exported_program.state_dict[node.name] = tensor.unsqueeze(0)
43+
self.exported_program.state_dict[name] = tensor.unsqueeze(0)
3944
node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
4045
tensor.unsqueeze(0), static_shapes=True
4146
)
@@ -53,6 +58,9 @@ def ensures(self, graph_module: torch.fx.GraphModule):
5358
if node.op == "placeholder":
5459
rank = node.meta["val"].dim()
5560
if rank == 0:
56-
if not self._is_inputs_to_buffers_or_parameters(node):
61+
if not (
62+
is_buffer(self.exported_program, node)
63+
or is_param(self.exported_program, node)
64+
):
5765
continue
5866
raise ValueError("Placeholders of rank 0 are not supported!")

0 commit comments

Comments
 (0)