diff --git a/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py b/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py
index 0276e65a081..ccae9b503cf 100644
--- a/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py
+++ b/backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py
@@ -7,6 +7,7 @@
 import torch
 
 from executorch.exir.pass_base import ExportPass, PassResult
+from torch._export.utils import is_buffer, is_param
 
 
 class UnsqueezeScalarPlaceholdersPass(ExportPass):
@@ -19,23 +20,27 @@ def __init__(self, exported_program):
         self.exported_program = exported_program
         super().__init__()
 
-    def _is_inputs_to_buffers_or_parameters(self, node):
-        return (
-            node.name in self.exported_program.graph_signature.inputs_to_buffers
-            or node.name in self.exported_program.graph_signature.inputs_to_parameters
-        )
-
     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             if node.op != "placeholder":
                 continue
             rank = node.meta["val"].dim()
             if rank == 0:
-                if not self._is_inputs_to_buffers_or_parameters(node):
+                if is_buffer(self.exported_program, node):
+                    name = self.exported_program.graph_signature.inputs_to_buffers[
+                        node.name
+                    ]
+                elif is_param(self.exported_program, node):
+                    name = self.exported_program.graph_signature.inputs_to_parameters[
+                        node.name
+                    ]
+                else:
                     continue
-                tensor = self.exported_program.state_dict[node.name]
+
+                tensor = self.exported_program.state_dict[name]
+
                 if tensor.dim() == 0:
-                    self.exported_program.state_dict[node.name] = tensor.unsqueeze(0)
+                    self.exported_program.state_dict[name] = tensor.unsqueeze(0)
                     node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
                         tensor.unsqueeze(0), static_shapes=True
                     )
@@ -53,6 +58,9 @@ def ensures(self, graph_module: torch.fx.GraphModule):
             if node.op == "placeholder":
                 rank = node.meta["val"].dim()
                 if rank == 0:
-                    if not self._is_inputs_to_buffers_or_parameters(node):
+                    if not (
+                        is_buffer(self.exported_program, node)
+                        or is_param(self.exported_program, node)
+                    ):
                         continue
                     raise ValueError("Placeholders of rank 0 are not supported!")