Skip to content

Commit 2944184

Browse files
authored
Merge branch 'gh/ahmtox/40/orig' into gh/ahmtox/41/orig
2 parents 9c28a88 + 4218a8f commit 2944184

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import torch
99
from executorch.exir.pass_base import ExportPass, PassResult
10+
from torch._export.utils import is_buffer, is_param
1011

1112

1213
class UnsqueezeScalarPlaceholdersPass(ExportPass):
@@ -19,23 +20,27 @@ def __init__(self, exported_program):
1920
self.exported_program = exported_program
2021
super().__init__()
2122

22-
def _is_inputs_to_buffers_or_parameters(self, node):
23-
return (
24-
node.name in self.exported_program.graph_signature.inputs_to_buffers
25-
or node.name in self.exported_program.graph_signature.inputs_to_parameters
26-
)
27-
2823
def call(self, graph_module: torch.fx.GraphModule):
2924
for node in graph_module.graph.nodes:
3025
if node.op != "placeholder":
3126
continue
3227
rank = node.meta["val"].dim()
3328
if rank == 0:
34-
if not self._is_inputs_to_buffers_or_parameters(node):
29+
if is_buffer(self.exported_program, node):
30+
name = self.exported_program.graph_signature.inputs_to_buffers[
31+
node.name
32+
]
33+
elif is_param(self.exported_program, node):
34+
name = self.exported_program.graph_signature.inputs_to_parameters[
35+
node.name
36+
]
37+
else:
3538
continue
36-
tensor = self.exported_program.state_dict[node.name]
39+
40+
tensor = self.exported_program.state_dict[name]
41+
3742
if tensor.dim() == 0:
38-
self.exported_program.state_dict[node.name] = tensor.unsqueeze(0)
43+
self.exported_program.state_dict[name] = tensor.unsqueeze(0)
3944
node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
4045
tensor.unsqueeze(0), static_shapes=True
4146
)
@@ -53,6 +58,9 @@ def ensures(self, graph_module: torch.fx.GraphModule):
5358
if node.op == "placeholder":
5459
rank = node.meta["val"].dim()
5560
if rank == 0:
56-
if not self._is_inputs_to_buffers_or_parameters(node):
61+
if not (
62+
is_buffer(self.exported_program, node)
63+
or is_param(self.exported_program, node)
64+
):
5765
continue
5866
raise ValueError("Placeholders of rank 0 are not supported!")

examples/arm/aot_arm_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
343343
"vgf",
344344
"TOSA-0.80+BI",
345345
"TOSA-1.0+INT",
346+
"TOSA-1.0+FP",
346347
]
347348

348349

0 commit comments

Comments
 (0)