1 change: 1 addition & 0 deletions backends/vulkan/_passes/remove_redundant_ops.py
@@ -30,6 +30,7 @@ class RemoveRedundantOpsTransform(ExportPass):
exir_ops.edge.aten.alias.default,
exir_ops.edge.aten.lift_fresh_copy.default,
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
exir_ops.edge.dim_order_ops._clone_dim_order.default,
}

def __init__(self) -> None:
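For context, a minimal sketch of the rewrite RemoveRedundantOpsTransform can now perform once _clone_dim_order is in its redundancy set: a dtype-preserving clone is bypassed by rewiring its users to its input. This is an illustrative torch.fx sketch under assumed names (remove_redundant_clones, redundant_targets), not the actual pass implementation.

import torch


def remove_redundant_clones(gm: torch.fx.GraphModule, redundant_targets: set) -> None:
    # Walk the graph and bypass clone-like nodes that only change dim order.
    for node in list(gm.graph.nodes):
        if node.op != "call_function" or node.target not in redundant_targets:
            continue
        in_arg = node.args[0]
        # Only safe when the dtype is unchanged, i.e. a dim-order-only copy.
        if (
            isinstance(in_arg, torch.fx.Node)
            and "val" in in_arg.meta
            and "val" in node.meta
            and in_arg.meta["val"].dtype == node.meta["val"].dtype
        ):
            node.replace_all_uses_with(in_arg)
            gm.graph.erase_node(node)
    gm.graph.eliminate_dead_code()
    gm.recompile()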
28 changes: 28 additions & 0 deletions backends/vulkan/op_registry.py
@@ -301,6 +301,34 @@ def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
)


@update_features(exir_ops.edge.dim_order_ops._clone_dim_order.default)
def register_clone_dim_order_op():
# Similar to to_dim_order_copy, _clone_dim_order can be removed as long as it
# does not change the dtype, i.e. the call only modifies the dim order.
# Therefore the node is safe to remove only when its input and output dtypes
# match.
def check_clone_dim_order_node(node: torch.fx.Node) -> bool:
in_arg = node.args[0]
if not isinstance(in_arg, torch.fx.Node):
return False

in_tensor = in_arg.meta.get("val", None)
out_tensor = node.meta.get("val", None)

if in_tensor is None or out_tensor is None:
return False

return in_tensor.dtype == out_tensor.dtype

return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
supports_resize=True,
are_node_inputs_supported_fn=check_clone_dim_order_node,
)


@update_features(
[
exir_ops.edge.aten.bmm.default,
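As a rough illustration of how the registration above is consumed, the per-node callback gates whether a _clone_dim_order node is handed to the Vulkan backend. The names below (node_is_supported, features_table) are assumptions for the sketch; the real partitioner plumbing differs.

import torch


def node_is_supported(node: torch.fx.Node, features_table: dict) -> bool:
    # Look up the OpFeatures entry registered for this op overload.
    features = features_table.get(node.target)
    if features is None:
        # Never registered, so the backend cannot claim the node.
        return False
    # Defer to the per-node check, e.g. check_clone_dim_order_node above.
    check = getattr(features, "are_node_inputs_supported_fn", None)
    return check(node) if check is not None else True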
12 changes: 8 additions & 4 deletions exir/passes/constant_prop_pass.py
@@ -294,8 +294,9 @@ def create_constant_nodes_and_return_specs(
)
return name_to_spec_dict


def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
# _skip_dim_order selects which clone op this pass introduces, so the inserted clone matches the dim order schema in use.
# TODO(gasoonjia): rely only on _clone_dim_order once the _skip_dim_order option is removed from EdgeCompileConfig.
def _update_output_node_and_specs(exported_program: ExportedProgram, _skip_dim_order: bool) -> None:
"""
Update the output node and output specs in the exported program.
In case a constant node is used as output, we replace it with a clone of the constant node.
@@ -307,14 +308,16 @@ def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
output_specs = exported_program.graph_signature.output_specs
assert len(output_nodes) == len(output_specs)

clone_op = exir_ops.edge.aten.clone.default if _skip_dim_order else exir_ops.edge.dim_order_ops._clone_dim_order.default

for i in range(len(output_specs)):
out_node = output_nodes[i]
if out_node not in updated_constant_placeholders:
continue

with exported_program.graph.inserting_after(out_node):
new_node = exported_program.graph.call_function(
exir_ops.edge.aten.clone.default, (out_node,)
clone_op, (out_node,)
)
assert "val" in out_node.meta
new_node.meta["val"] = out_node.meta["val"]
@@ -329,6 +332,7 @@ def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
def constant_prop_pass(
exported_program: ExportedProgram,
custom_skip_targets: Optional[set[EdgeOpOverload]] = None,
_skip_dim_order: bool = True,
) -> ExportedProgram:
"""
This pass is for constant propagation for Exported Program with lifted parameters,
@@ -376,7 +380,7 @@ constant_prop_pass(
new_input_specs.append(name_to_spec_dict[node.name])
exported_program.graph_signature.input_specs = new_input_specs

_update_output_node_and_specs(exported_program)
_update_output_node_and_specs(exported_program, _skip_dim_order=_skip_dim_order)

# Cleanup the graph.
exported_program.graph.eliminate_dead_code()
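A hedged usage sketch of the new flag: pass the same _skip_dim_order value that went into the EdgeCompileConfig which produced the edge program, so the clone inserted for constant outputs belongs to the same op family as the rest of the graph. The toy module and export flow below are illustrative assumptions; only the constant_prop_pass call mirrors this change.

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.passes.constant_prop_pass import constant_prop_pass


class AddConst(torch.nn.Module):
    def forward(self, x):
        return x + torch.ones(4)


skip_dim_order = False  # dim-order ops enabled in the edge dialect
edge = to_edge(
    torch.export.export(AddConst(), (torch.rand(4),)),
    compile_config=EdgeCompileConfig(_skip_dim_order=skip_dim_order),
)
# Keeping the flag consistent means the inserted clone is
# _clone_dim_order rather than aten.clone when dim order is enabled.
edge._edge_programs["forward"] = constant_prop_pass(
    edge.exported_program("forward"), _skip_dim_order=skip_dim_order
)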
5 changes: 1 addition & 4 deletions exir/passes/memory_format_ops_pass.py
@@ -19,9 +19,6 @@
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)

# TODO - these passes are too specialized on a single to_copy op.
# We should be able to replace (or revert) any of the dim_order ops in the future.


class MemoryFormatOpsPass(ExportPass):
"""
@@ -43,7 +40,7 @@ def call_operator(self, op, args, kwargs, meta):
# new kwargs with dim_order, and no memory_format for the new op
nkwargs = dict(copy.deepcopy(kwargs)) # orig kwargs are immutable

# get the "to" memory format for the EdgeOp
# get the target memory format for the EdgeOp
mem_format = nkwargs.pop("memory_format", torch.contiguous_format)

# can always get the shape, assuming rank is specialized
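On the reworded comment: the memory_format popped from the kwargs is what the pass translates into an explicit dim order for the dim-order op. A small standalone sketch (plain torch, not the pass itself) of that correspondence:

import torch

# Recover the dim order implied by a memory format by sorting dimensions
# from outermost to innermost stride; channels_last on a 4-D tensor gives
# (0, 2, 3, 1), i.e. NHWC.
x = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
dim_order = tuple(sorted(range(x.dim()), key=lambda d: x.stride(d), reverse=True))
print(dim_order)  # (0, 2, 3, 1)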
2 changes: 1 addition & 1 deletion exir/tests/test_passes.py
@@ -1192,7 +1192,7 @@ def forward(self) -> torch.Tensor:
)

edge._edge_programs["forward"] = constant_prop_pass(
edge.exported_program("forward")
edge.exported_program("forward"), _skip_dim_order=False
)

# Check (c_lifted_tensor_*) nodes are all replaced by _prop_tensor_constant.