Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions backends/cadence/aot/pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,30 +174,53 @@ def nodes_not_adjacent_in_gm(

def get_arg(
node: torch.fx.Node,
arg_index: int,
kwarg_name: str,
*,
default: torch.fx.node.Argument = None,
) -> torch.fx.node.Argument:
"""
Get the arg at arg_index or kwarg with arg_name of the node. If neither is found
return default.
Get the arg with arg_name of the node, returns default value if not set.
"""
if arg_index < len(node.args):
return node.args[arg_index]
elif kwarg_name in node.kwargs:
# Try to get the arg from kwargs first since this is faster
if kwarg_name in node.kwargs:
return node.kwargs[kwarg_name]
else:
return default

# If it's not found in kwargs, try to normalize the args
normalized_args = node.normalized_arguments(
node.graph.owning_module, normalize_to_only_use_kwargs=True
)
if not normalized_args:
raise RuntimeError(
f"get_arg: Node {node} does not support normalization of arguments"
)

return normalized_args.kwargs[kwarg_name]


def set_arg(
node: torch.fx.Node, arg_index: int, kwarg_name: str, value: torch.fx.node.Argument
node: torch.fx.Node, kwarg_name: str, value: torch.fx.node.Argument
) -> None:
"""
Set the arg at arg_index if it exists, otherwise set the kwarg.
Set the node's arg with its name to the given value.
"""
if arg_index < len(node.args):
node.update_arg(arg_index, value)
# Try to set the arg if it is present in kwargs first since this is faster
if kwarg_name in node.kwargs:
node.update_kwarg(kwarg_name, value)
return

# If it's not found in kwargs, try to normalize the args and set the arg
normalized_args = node.normalized_arguments(
node.graph.owning_module, normalize_to_only_use_kwargs=True
)
if not normalized_args:
raise RuntimeError(
f"set_arg: Node {node} does not support normalization of arguments"
)

kwargs = normalized_args.kwargs
if kwarg_name not in kwargs:
raise ValueError(f"set_arg: invalid arg name {kwarg_name} for node {node} used")

idx = list(kwargs.keys()).index(kwarg_name)
if idx < len(node.args):
node.update_arg(idx, value)
else:
node.update_kwarg(kwarg_name, value)
18 changes: 9 additions & 9 deletions backends/cadence/aot/remove_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,17 +779,17 @@ def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
for slice_copy_node in graph_module.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
):
cat_node = cast(Node, get_arg(slice_copy_node, 0, "input"))
slice_dim = cast(int, get_arg(slice_copy_node, 1, "dim", default=0))
start_idx = cast(int, get_arg(slice_copy_node, 2, "start", default=None))
end_idx = cast(int, get_arg(slice_copy_node, 3, "end", default=None))
step = cast(int, get_arg(slice_copy_node, 4, "step", default=1))
cat_node = cast(Node, get_arg(slice_copy_node, "input"))
slice_dim = cast(int, get_arg(slice_copy_node, "dim"))
start_idx = cast(int, get_arg(slice_copy_node, "start"))
end_idx = cast(int, get_arg(slice_copy_node, "end"))
step = cast(int, get_arg(slice_copy_node, "step"))

if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
continue

# Make sure cat and slice happens on the same dimension.
cat_dim = cast(Node, get_arg(cat_node, 1, "dim", default=0))
cat_dim = cast(Node, get_arg(cat_node, "dim"))
if cat_dim != slice_dim:
continue

Expand All @@ -805,14 +805,14 @@ def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
end_idx += cat_output_shape[cat_dim]

offset = 0
for cat_input_node in cast(List[Node], get_arg(cat_node, 0, "tensors")):
for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")):
cat_input_shape = cat_input_node.meta["val"].shape

# Check if the slice range overlaps with the cat input range.
if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
slice_copy_node.replace_input_with(cat_node, cat_input_node)
set_arg(slice_copy_node, 2, "start", start_idx - offset)
set_arg(slice_copy_node, 3, "end", end_idx - offset)
set_arg(slice_copy_node, "start", start_idx - offset)
set_arg(slice_copy_node, "end", end_idx - offset)
break

offset += cat_input_shape[cat_dim]
Expand Down
Loading