2 changes: 1 addition & 1 deletion backends/cadence/aot/memory_constraints.py
@@ -15,7 +15,7 @@
     create_cadence_pass_filter,
     register_cadence_pass,
 )
-from executorch.backends.cadence.aot.utils import get_shape, is_node_in_flattened_output
+from executorch.backends.cadence.aot.compiler_utils import get_shape, is_node_in_flattened_output
 from executorch.exir import memory
 from executorch.exir.pass_manager import PassManager
 from executorch.exir.tensor import num_bytes_from_shape_and_dtype, TensorSpec
43 changes: 0 additions & 43 deletions backends/cadence/aot/utils.py
@@ -124,49 +124,6 @@ def get_ops_count(graph_module: torch.fx.GraphModule) -> Dict[str, int]:
    return freq


# Return the output node of the graph
def get_output_node(graph: torch.fx.Graph) -> torch.fx.Node:
    assert graph is not None, "Cannot get output of an empty graph"
    output_node = next(iter(reversed(graph.nodes)))
    assert (
        output_node and output_node.op == "output" and len(output_node.args) == 1
    ), "Failed to find output node"
    return output_node
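

# Usage sketch added for illustration (not part of the original file):
# torch.fx tracing appends exactly one `output` node to the graph, so the
# reversed iteration above finds it in one step. `_ExampleAdd` and the
# function name below are hypothetical.
def _example_get_output_node() -> None:
    import torch.fx

    class _ExampleAdd(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + 1

    gm = torch.fx.symbolic_trace(_ExampleAdd())
    node = get_output_node(gm.graph)
    assert node.op == "output" and len(node.args) == 1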


# Return true if the node is part of the flattened output
def is_node_in_flattened_output(graph: torch.fx.Graph, node: torch.fx.Node) -> bool:
    output_node = get_output_node(graph)
    return node in tree_flatten(output_node.args[0])[0]
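

# Usage sketch added for illustration (not part of the original file): for a
# module returning a tuple, output_node.args[0] is a pytree whose leaves are
# the individual result nodes, which tree_flatten exposes for membership
# tests. `_ExampleTwoOutputs` is a hypothetical module.
def _example_is_node_in_flattened_output() -> None:
    import operator

    import torch.fx

    class _ExampleTwoOutputs(torch.nn.Module):
        def forward(self, x: torch.Tensor):
            return x + 1, x * 2

    gm = torch.fx.symbolic_trace(_ExampleTwoOutputs())
    add_node = next(n for n in gm.graph.nodes if n.target is operator.add)
    assert is_node_in_flattened_output(gm.graph, add_node)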


# Return the shape of the incoming node.
def get_shape(
    graph_module: torch.fx.GraphModule, node: torch.fx.Node
) -> Union[torch.Size, None]:
    """
    Return the shape of the tensor corresponding to the node. If the node has
    a tensor spec, return the shape from the metadata. If the node is a param,
    return its shape. Otherwise, return None.
    """
    try:
        # Case 1. node is a scalar
        if isinstance(node, (float, int, bool)):
            return torch.Size([1])
        # Case 2. node has TensorSpec metadata
        fake_tensor = node.meta.get("val")
        if fake_tensor is not None:
            return fake_tensor.shape
        # Case 3. node holds a param
        if node.op == "get_attr":
            attr_node = getattr(graph_module, node.target)
            return attr_node.shape
        # Default: return None
        return None
    except RuntimeError:
        return None
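

# Illustrative check added for this writeup (hypothetical, not part of the
# original file), exercising case 3 above: a get_attr node reports the shape
# of the parameter stored on the module, while a node with no `val` metadata
# and no attribute would fall through to None.
def _example_get_shape() -> None:
    import torch.fx

    class _ExampleLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(3, 4))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x @ self.weight

    gm = torch.fx.symbolic_trace(_ExampleLinear())
    attr_node = next(n for n in gm.graph.nodes if n.op == "get_attr")
    assert get_shape(gm, attr_node) == torch.Size([3, 4])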


# Print the ops and how many times they occur in multiple graph modules:
# from export, from to_edge, and from final. Print the available
# implementations for each op, and error out if the op is not supported.