File tree Expand file tree Collapse file tree 5 files changed +60
-6
lines changed Expand file tree Collapse file tree 5 files changed +60
-6
lines changed Original file line number Diff line number Diff line change 3232 CompileSpec ,
3333 PreprocessResult ,
3434)
35+
36+ from executorch .exir .passes .memory_format_ops_pass import DimOrderOpsRevertPass
37+ from executorch .exir .program ._program import _transform
3538from torch .export .exported_program import ExportedProgram
3639
3740FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -83,6 +86,9 @@ def preprocess(
8386 # FlatBuffer graph, process the `output` nodes and add their id to
8487 # the `output_ids` array in the schema.
8588
89+ # TODO: Remove this once we have a better support for the dim-order ops.
90+ edge_program = _transform (edge_program , DimOrderOpsRevertPass ())
91+
8692 mps_graph = MPSGraph (
8793 version = "0" ,
8894 mps_nodes = [],
Original file line number Diff line number Diff line change @@ -79,6 +79,25 @@ def define_node(
7979 )
8080
8181
82+ @register_node_visitor
83+ class ToDimOrderEmptyVisitor (NodeVisitor ):
84+ target = ["dim_order_ops._empty_dim_order.default" ]
85+
86+ def __init__ (self , * args ) -> None :
87+ super ().__init__ (* args )
88+
89+ def define_node (
90+ self ,
91+ node : torch .fx .Node ,
92+ mps_graph : MPSGraph ,
93+ ) -> None :
94+ # We should never get here, because DimOrderOpsRevertPass replaces this with an aten.empty.memory_format op
95+ # But if we do, we can't handle it ATM, so raise an exception
96+ raise NotImplementedError (
97+ "dim_order_ops._empty_dim_order.default is not supported yet"
98+ )
99+
100+
82101@register_node_visitor
83102class FullLikeVisitor (NodeVisitor ):
84103 target = "aten.full_like.default"
Original file line number Diff line number Diff line change @@ -33,3 +33,22 @@ def define_node(
3333 )
3434 input_id = self .define_tensor (get_input_node (node , 0 ), mps_graph )
3535 self .tensor_to_id [node ] = input_id
36+
37+
38+ @register_node_visitor
39+ class ToDimOrderCopyVisitor (NodeVisitor ):
40+ target = ["dim_order_ops._to_dim_order_copy.default" ]
41+
42+ def __init__ (self , * args ) -> None :
43+ super ().__init__ (* args )
44+
45+ def define_node (
46+ self ,
47+ node : torch .fx .Node ,
48+ mps_graph : MPSGraph ,
49+ ) -> None :
50+ # We should never get here, because DimOrderOpsRevertPass replaces this with an aten._to_copy op
51+ # But if we do, we can't handle it ATM, so raise an exception
52+ raise NotImplementedError (
53+ "dim_order_ops._to_dim_order_copy.default is not supported yet"
54+ )
Original file line number Diff line number Diff line change @@ -1829,6 +1829,21 @@ def forward(self, x):
18291829 Clone (), model_inputs , func_name = inspect .stack ()[0 ].function [5 :]
18301830 )
18311831
1832+ def test_mps_backend_to_copy (self ):
1833+ class Copy (torch .nn .Module ):
1834+ def forward (self , x ):
1835+ return (
1836+ torch .ops .aten ._to_copy .default (
1837+ x + 2 , memory_format = torch .contiguous_format
1838+ )
1839+ + x
1840+ )
1841+
1842+ model_inputs = (torch .randn (1 , 3 , 3 ),)
1843+ self .lower_and_test_with_partitioner (
1844+ Copy (), model_inputs , func_name = inspect .stack ()[0 ].function [5 :]
1845+ )
1846+
18321847 def test_mps_backend_floor (self ):
18331848 class Floor (torch .nn .Module ):
18341849 def forward (self , x ):
Original file line number Diff line number Diff line change 2626
2727# Config for Capturing the weights, will be moved in the future
2828
29- # TODO(T182928844): Delegate dim order op to backend.
30- _EDGE_COMPILE_CONFIG = exir .EdgeCompileConfig (
31- _check_ir_validity = False , _skip_dim_order = True
32- )
29+ _EDGE_COMPILE_CONFIG = exir .EdgeCompileConfig (_check_ir_validity = False )
3330
3431
3532class ansi_colors :
@@ -219,7 +216,6 @@ def lower_module_and_test_output(
219216 dynamic_shapes = dynamic_shapes ,
220217 edge_compile_config = EdgeCompileConfig (
221218 _check_ir_validity = False ,
222- _skip_dim_order = True , # TODO(T182928844): Delegate dim order op to backend.
223219 ),
224220 )
225221
@@ -253,7 +249,6 @@ def lower_module_and_test_output(
253249 ),
254250 compile_config = exir .EdgeCompileConfig (
255251 _check_ir_validity = False ,
256- _skip_dim_order = True , # TODO(T182928844): Delegate dim order op to backend.
257252 ),
258253 ).to_executorch (
259254 config = ExecutorchBackendConfig (extract_delegate_segments = False )
You can’t perform that action at this time.
0 commit comments