28 changes: 18 additions & 10 deletions backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py
@@ -7,6 +7,7 @@
 
 import torch
 from executorch.exir.pass_base import ExportPass, PassResult
+from torch._export.utils import is_buffer, is_param
 
 
 class UnsqueezeScalarPlaceholdersPass(ExportPass):
@@ -19,23 +20,27 @@ def __init__(self, exported_program):
         self.exported_program = exported_program
         super().__init__()
 
-    def _is_inputs_to_buffers_or_parameters(self, node):
-        return (
-            node.name in self.exported_program.graph_signature.inputs_to_buffers
-            or node.name in self.exported_program.graph_signature.inputs_to_parameters
-        )
-
     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
            if node.op != "placeholder":
                continue
            rank = node.meta["val"].dim()
            if rank == 0:
-                if not self._is_inputs_to_buffers_or_parameters(node):
+                if is_buffer(self.exported_program, node):
Inline review thread on the line above:

Contributor: nit: make a util def get_name(self, node)?

Collaborator (author): Local to this file? Seems a little redundant to me; I suppose it would look something like this:

    def get_name(self, node):
        if is_buffer(self.exported_program, node):
            return self.exported_program.graph_signature.inputs_to_buffers[node.name]
        # Repeat for param etc.
        ...

And this is only done twice.

I think it would belong in torch._export.utils, like is_buffer etc.

Contributor: Even a local one would make it cleaner. Feel free to put up a PR against _export.utils if you feel like others are doing this or should do it this way.
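
For reference, a filled-in version of the helper sketched in the thread might look like the following. It is hypothetical: the thread left it as a nit and no such helper was added in this PR; both branches are taken verbatim from the pass below.

    # Hypothetical helper from the review discussion above; not part of the merged change.
    def get_name(self, node):
        if is_buffer(self.exported_program, node):
            return self.exported_program.graph_signature.inputs_to_buffers[node.name]
        if is_param(self.exported_program, node):
            return self.exported_program.graph_signature.inputs_to_parameters[node.name]
        return None  # neither a buffer nor a parameter input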

(diff continues)

+                    name = self.exported_program.graph_signature.inputs_to_buffers[
+                        node.name
+                    ]
+                elif is_param(self.exported_program, node):
+                    name = self.exported_program.graph_signature.inputs_to_parameters[
+                        node.name
+                    ]
+                else:
                     continue
-                tensor = self.exported_program.state_dict[node.name]
+
+                tensor = self.exported_program.state_dict[name]
+
                 if tensor.dim() == 0:
-                    self.exported_program.state_dict[node.name] = tensor.unsqueeze(0)
+                    self.exported_program.state_dict[name] = tensor.unsqueeze(0)
                     node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
                         tensor.unsqueeze(0), static_shapes=True
                     )
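
Editorial note (not part of the PR): the switch from state_dict[node.name] to state_dict[name] matters because, in torch.export, a buffer's placeholder node name generally differs from its state_dict key; the graph signature holds the mapping. A minimal sketch, assuming a made-up module with a scalar buffer named "scale":

    import torch
    from torch.export import export


    class Mod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Rank-0 (scalar) buffer: the case this pass rewrites to rank 1.
            self.register_buffer("scale", torch.tensor(2.0))

        def forward(self, x):
            return x * self.scale


    ep = export(Mod(), (torch.ones(2),))
    for node in ep.graph.nodes:
        if node.op == "placeholder" and node.name in ep.graph_signature.inputs_to_buffers:
            target = ep.graph_signature.inputs_to_buffers[node.name]
            # The placeholder's node name (typically "b_scale" here) differs from
            # the state_dict key ("scale"), hence the PR's lookup via `name`.
            print(node.name, "->", target, ep.state_dict[target].dim())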
@@ -53,6 +58,9 @@ def ensures(self, graph_module: torch.fx.GraphModule):
             if node.op == "placeholder":
                 rank = node.meta["val"].dim()
                 if rank == 0:
-                    if not self._is_inputs_to_buffers_or_parameters(node):
+                    if not (
+                        is_buffer(self.exported_program, node)
+                        or is_param(self.exported_program, node)
+                    ):
                         continue
                     raise ValueError("Placeholders of rank 0 are not supported!")
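
A hedged end-to-end sketch (not from the PR) of running the pass on an exported program that holds a scalar buffer. The module and buffer name are made up, and the import path is inferred from the file path in the diff header:

    import torch
    from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
        UnsqueezeScalarPlaceholdersPass,
    )
    from torch.export import export


    class AddBias(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("bias", torch.tensor(1.0))  # rank-0 buffer

        def forward(self, x):
            return x + self.bias


    ep = export(AddBias(), (torch.ones(4),))
    UnsqueezeScalarPlaceholdersPass(ep).call(ep.graph_module)
    # After the pass, the scalar buffer is stored with rank 1, so ensures()
    # no longer finds a rank-0 buffer/parameter placeholder.
    assert ep.state_dict["bias"].dim() == 1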