Commit 245630a
forward fix for D81697327 (#14113)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #14088 by @Gasoonjia
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/39/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/39/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/39/orig
@diff-train-skip-merge

Co-authored-by: gasoonjia <[email protected]>
1 parent 59b9707 commit 245630a

5 files changed (+43, -10 lines)


backends/vulkan/_passes/remove_redundant_ops.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ class RemoveRedundantOpsTransform(ExportPass):
         exir_ops.edge.aten.alias.default,
         exir_ops.edge.aten.lift_fresh_copy.default,
         exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+        exir_ops.edge.dim_order_ops._clone_dim_order.default,
     }
 
     def __init__(self) -> None:

backends/vulkan/op_registry.py

Lines changed: 26 additions & 0 deletions
@@ -301,6 +301,32 @@ def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
     )
 
 
+@update_features(exir_ops.edge.dim_order_ops._clone_dim_order.default)
+def register_clone_dim_order_op():
+    # Similar to to_dim_order_copy, _clone_dim_order can be removed as long as the
+    # operator is not changing the dtype, i.e. the operator call is modifying the dim
+    # order only. Therefore, check that the input and output dtypes are the same; if so,
+    # the operator is safe to remove.
+    def check_clone_dim_order_node(node: torch.fx.Node) -> bool:
+        in_arg = node.args[0]
+        if not isinstance(in_arg, torch.fx.Node):
+            return False
+
+        in_tensor = in_arg.meta.get("val", None)
+        out_tensor = node.meta.get("val", None)
+
+        if in_tensor.dtype != out_tensor.dtype:
+            return False
+
+        return True
+
+    return OpFeatures(
+        inputs_storage=utils.ANY_STORAGE,
+        supports_resize=True,
+        are_node_inputs_supported_fn=check_clone_dim_order_node,
+    )
+
+
 @update_features(
     [
         exir_ops.edge.aten.bmm.default,
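
To illustrate the condition the new check enforces, here is a minimal sketch (not part of the diff above; the helper name is for illustration only): a clone that only changes dim order keeps the dtype and is removable, while a dtype-changing clone must be kept.

import torch

def dtype_preserved(in_val: torch.Tensor, out_val: torch.Tensor) -> bool:
    # Mirrors the in_tensor.dtype != out_tensor.dtype test on node.meta["val"].
    return in_val.dtype == out_val.dtype

x = torch.randn(1, 3, 8, 8)
dim_order_only = x.clone(memory_format=torch.channels_last)  # layout changes, dtype stays
dtype_change = x.to(torch.float16)                           # dtype changes

print(dtype_preserved(x, dim_order_only))  # True  -> removable
print(dtype_preserved(x, dtype_change))    # False -> must be kept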

exir/passes/constant_prop_pass.py

Lines changed: 14 additions & 5 deletions
@@ -295,7 +295,11 @@ def create_constant_nodes_and_return_specs(
     return name_to_spec_dict
 
 
-def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
+# add _skip_dim_order to ensure the correct clone node is introduced for different dim order schemas
+# TODO(gasoonjia): rely only on _clone_dim_order once we remove the _skip_dim_order option in EdgeCompileConfig
+def _update_output_node_and_specs(
+    exported_program: ExportedProgram, _skip_dim_order: bool
+) -> None:
     """
     Update the output node and output specs in the exported program.
     In case a constant node is used as output, we replace it with a clone of the constant node.
@@ -307,15 +311,19 @@ def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
     output_specs = exported_program.graph_signature.output_specs
     assert len(output_nodes) == len(output_specs)
 
+    clone_op = (
+        exir_ops.edge.aten.clone.default
+        if _skip_dim_order
+        else exir_ops.edge.dim_order_ops._clone_dim_order.default
+    )
+
     for i in range(len(output_specs)):
         out_node = output_nodes[i]
         if out_node not in updated_constant_placeholders:
             continue
 
         with exported_program.graph.inserting_after(out_node):
-            new_node = exported_program.graph.call_function(
-                exir_ops.edge.aten.clone.default, (out_node,)
-            )
+            new_node = exported_program.graph.call_function(clone_op, (out_node,))
             assert "val" in out_node.meta
             new_node.meta["val"] = out_node.meta["val"]
             output_nodes[i] = new_node
@@ -329,6 +337,7 @@ def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
 def constant_prop_pass(
     exported_program: ExportedProgram,
     custom_skip_targets: Optional[set[EdgeOpOverload]] = None,
+    _skip_dim_order: bool = True,
 ) -> ExportedProgram:
     """
     This pass is for constant propagation for Exported Program with lifted parameters,
@@ -376,7 +385,7 @@ def constant_prop_pass(
         new_input_specs.append(name_to_spec_dict[node.name])
     exported_program.graph_signature.input_specs = new_input_specs
 
-    _update_output_node_and_specs(exported_program)
+    _update_output_node_and_specs(exported_program, _skip_dim_order=_skip_dim_order)
 
     # Cleanup the graph.
     exported_program.graph.eliminate_dead_code()
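
A hedged usage sketch of the new flag (not taken from this commit; it assumes the module path executorch.exir.passes.constant_prop_pass and an already-lowered ExportedProgram named ep): the flag selects which clone op guards constant outputs.

from executorch.exir.passes.constant_prop_pass import constant_prop_pass

# ep: an ExportedProgram in the edge dialect, assumed to exist already.

# Default behavior: dim-order ops are skipped, so constant outputs are guarded
# with exir_ops.edge.aten.clone.default.
ep = constant_prop_pass(ep)

# Dim-order-aware graphs: pass _skip_dim_order=False so the guard uses
# exir_ops.edge.dim_order_ops._clone_dim_order.default instead.
ep = constant_prop_pass(ep, _skip_dim_order=False)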

exir/passes/memory_format_ops_pass.py

Lines changed: 1 addition & 4 deletions
@@ -19,9 +19,6 @@
 logger = logging.getLogger(__file__)
 logger.setLevel(logging.INFO)
 
-# TODO - these passes are too specialized on a single to_copy op.
-# We should be able to replace (or revert) any of the dim_order ops in the future.
-
 
 class MemoryFormatOpsPass(ExportPass):
     """
@@ -43,7 +40,7 @@ def call_operator(self, op, args, kwargs, meta):
         # new kwargs with dim_order, and no memory_format for the new op
         nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable
 
-        # get the "to" memory format for the EdgeOp
+        # get the target memory format for the EdgeOp
         mem_format = nkwargs.pop("memory_format", torch.contiguous_format)
 
         # can always get the shape, assuming rank is specialized
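
For background on that comment: MemoryFormatOpsPass drops the memory_format kwarg and encodes the layout as an explicit dim_order on the dim_order op. A rough sketch of that mapping follows (illustrative only, assuming a rank-4 tensor; the helper is not the pass's actual implementation).

import torch

def memory_format_to_dim_order(rank: int, mem_format: torch.memory_format) -> list[int]:
    # channels_last only applies to rank-4 tensors; everything else falls back
    # to the contiguous (identity) dim order.
    if mem_format == torch.channels_last and rank == 4:
        return [0, 2, 3, 1]
    return list(range(rank))

print(memory_format_to_dim_order(4, torch.contiguous_format))  # [0, 1, 2, 3]
print(memory_format_to_dim_order(4, torch.channels_last))      # [0, 2, 3, 1]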

exir/tests/test_passes.py

Lines changed: 1 addition & 1 deletion
@@ -1192,7 +1192,7 @@ def forward(self) -> torch.Tensor:
         )
 
         edge._edge_programs["forward"] = constant_prop_pass(
-            edge.exported_program("forward")
+            edge.exported_program("forward"), _skip_dim_order=False
         )
 
         # Check (c_lifted_tensor_*) nodes are all replaced by _prop_tensor_constant.
