Skip to content

Commit fc50a13

Browse files
digantdesaifacebook-github-bot
authored andcommitted
Add dim_order compat support (#7420)
Summary: Pull Request resolved: #7420 Differential Revision: D67542995
1 parent 82763a9 commit fc50a13

File tree

5 files changed

+49
-1
lines changed

5 files changed

+49
-1
lines changed

backends/apple/mps/mps_preprocess.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
CompileSpec,
3333
PreprocessResult,
3434
)
35+
36+
from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
37+
from executorch.exir.program._program import _transform
3538
from torch.export.exported_program import ExportedProgram
3639

3740
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -83,6 +86,9 @@ def preprocess(
8386
# FlatBuffer graph, process the `output` nodes and add their id to
8487
# the `output_ids` array in the schema.
8588

89+
# TODO: Remove this once we have a better support for the dim-order ops.
90+
edge_program = _transform(edge_program, DimOrderOpsRevertPass())
91+
8692
mps_graph = MPSGraph(
8793
version="0",
8894
mps_nodes=[],

backends/apple/mps/operators/constant_ops.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,22 @@ def define_node(
7878
)
7979
)
8080

81+
@register_node_visitor
82+
class ToDimOrderEmptyVisitor(NodeVisitor):
83+
target = ["dim_order_ops._empty_dim_order.default"]
84+
85+
def __init__(self, *args) -> None:
86+
super().__init__(*args)
87+
88+
def define_node(
89+
self,
90+
node: torch.fx.Node,
91+
mps_graph: MPSGraph,
92+
) -> None:
93+
# We should never get here, because DimOrderOpsRevertPass replaces this with an aten.empty.memory_format op
94+
# But if we do, we can't handle it ATM, so raise an exception
95+
raise NotImplementedError("dim_order_ops._empty_dim_order.default is not supported yet")
96+
8197

8298
@register_node_visitor
8399
class FullLikeVisitor(NodeVisitor):

backends/apple/mps/operators/op_clone.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,19 @@ def define_node(
3333
)
3434
input_id = self.define_tensor(get_input_node(node, 0), mps_graph)
3535
self.tensor_to_id[node] = input_id
36+
37+
@register_node_visitor
38+
class ToDimOrderCopyVisitor(NodeVisitor):
39+
target = ["dim_order_ops._to_dim_order_copy.default"]
40+
41+
def __init__(self, *args) -> None:
42+
super().__init__(*args)
43+
44+
def define_node(
45+
self,
46+
node: torch.fx.Node,
47+
mps_graph: MPSGraph,
48+
) -> None:
49+
# We should never get here, because DimOrderOpsRevertPass replaces this with an aten._to_copy op
50+
# But if we do, we can't handle it ATM, so raise an exception
51+
raise NotImplementedError("dim_order_ops._to_dim_order_copy.default is not supported yet")

backends/apple/mps/test/test_mps.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1829,6 +1829,16 @@ def forward(self, x):
18291829
Clone(), model_inputs, func_name=inspect.stack()[0].function[5:]
18301830
)
18311831

1832+
def test_mps_backend_to_copy(self):
1833+
class Copy(torch.nn.Module):
1834+
def forward(self, x):
1835+
return torch.ops.aten._to_copy.default(x + 2, memory_format=torch.contiguous_format) + x
1836+
1837+
model_inputs = (torch.randn(1, 3, 3),)
1838+
self.lower_and_test_with_partitioner(
1839+
Copy(), model_inputs, func_name=inspect.stack()[0].function[5:]
1840+
)
1841+
18321842
def test_mps_backend_floor(self):
18331843
class Floor(torch.nn.Module):
18341844
def forward(self, x):

backends/apple/mps/test/test_mps_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def lower_module_and_test_output(
219219
dynamic_shapes=dynamic_shapes,
220220
edge_compile_config=EdgeCompileConfig(
221221
_check_ir_validity=False,
222-
_skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend.
222+
_skip_dim_order=False, # TODO(T182928844): Delegate dim order op to backend.
223223
),
224224
)
225225

0 commit comments

Comments
 (0)