5 files changed: +60 −6 lines changed

File 1 of 5:
@@ -32,6 +32,9 @@
     CompileSpec,
     PreprocessResult,
 )
+
+from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
+from executorch.exir.program._program import _transform
 from torch.export.exported_program import ExportedProgram

 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -83,6 +86,9 @@ def preprocess(
         #    FlatBuffer graph, process the `output` nodes and add their id to
         #    the `output_ids` array in the schema.

+        # TODO: Remove this once we have better support for dim-order ops.
+        edge_program = _transform(edge_program, DimOrderOpsRevertPass())
+
         mps_graph = MPSGraph(
             version="0",
             mps_nodes=[],
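
For context, here is a minimal sketch, not taken from this PR, of what the revert step above does to an edge program that contains dim-order ops. It assumes to_edge, _transform, and DimOrderOpsRevertPass behave as they are used in the diff; the Copy module mirrors the test case added further down, and the printed graphs are only meant to show the op rewrite.

import torch

from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
from executorch.exir.program._program import _transform


class Copy(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten._to_copy.default(x, memory_format=torch.contiguous_format)


exported = torch.export.export(Copy(), (torch.randn(1, 3, 3),))

# With dim-order lowering enabled (no _skip_dim_order), to_edge() rewrites
# aten._to_copy into dim_order_ops._to_dim_order_copy in the edge dialect.
edge = to_edge(exported, compile_config=EdgeCompileConfig(_check_ir_validity=False))
edge.exported_program().graph.print_tabular()

# The MPS preprocess step runs DimOrderOpsRevertPass so that the existing
# aten-based node visitors can serialize the graph unchanged.
reverted = _transform(edge.exported_program(), DimOrderOpsRevertPass())
reverted.graph.print_tabular()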
File 2 of 5:
@@ -79,6 +79,25 @@ def define_node(
         )


+@register_node_visitor
+class ToDimOrderEmptyVisitor(NodeVisitor):
+    target = ["dim_order_ops._empty_dim_order.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        mps_graph: MPSGraph,
+    ) -> None:
+        # We should never get here: DimOrderOpsRevertPass replaces this op with aten.empty.memory_format.
+        # If we do reach it, we can't handle it yet, so raise an exception.
+        raise NotImplementedError(
+            "dim_order_ops._empty_dim_order.default is not supported yet"
+        )
+
+
 @register_node_visitor
 class FullLikeVisitor(NodeVisitor):
     target = "aten.full_like.default"
File 3 of 5:
@@ -33,3 +33,22 @@ def define_node(
                 )
         input_id = self.define_tensor(get_input_node(node, 0), mps_graph)
         self.tensor_to_id[node] = input_id
+
+
+@register_node_visitor
+class ToDimOrderCopyVisitor(NodeVisitor):
+    target = ["dim_order_ops._to_dim_order_copy.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        mps_graph: MPSGraph,
+    ) -> None:
+        # We should never get here: DimOrderOpsRevertPass replaces this op with aten._to_copy.
+        # If we do reach it, we can't handle it yet, so raise an exception.
+        raise NotImplementedError(
+            "dim_order_ops._to_dim_order_copy.default is not supported yet"
+        )
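
The comments in these two fallback visitors lean on the rewrite performed by DimOrderOpsRevertPass. As a rough illustration of that mechanism, here is a toy pass, not the real implementation, that swaps the dim-order copy op back to its aten counterpart using executorch's ExportPass hook. The dim_order-to-memory_format argument translation done by the real pass is deliberately elided, and the edge-dialect op handles are assumptions based on how the target strings above are spelled.

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class ToyDimOrderRevertPass(ExportPass):
    """Toy sketch: map dim_order_ops._to_dim_order_copy back to aten._to_copy."""

    def call_operator(self, op, args, kwargs, meta):
        if op == exir_ops.edge.dim_order_ops._to_dim_order_copy.default:
            # Drop the dim_order argument; the real pass converts it into an
            # equivalent torch.memory_format before re-emitting the op.
            new_kwargs = {k: v for k, v in kwargs.items() if k != "dim_order"}
            return super().call_operator(
                exir_ops.edge.aten._to_copy.default, args, new_kwargs, meta
            )
        return super().call_operator(op, args, kwargs, meta)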
File 4 of 5:
@@ -1829,6 +1829,21 @@ def forward(self, x):
             Clone(), model_inputs, func_name=inspect.stack()[0].function[5:]
         )

+    def test_mps_backend_to_copy(self):
+        class Copy(torch.nn.Module):
+            def forward(self, x):
+                return (
+                    torch.ops.aten._to_copy.default(
+                        x + 2, memory_format=torch.contiguous_format
+                    )
+                    + x
+                )
+
+        model_inputs = (torch.randn(1, 3, 3),)
+        self.lower_and_test_with_partitioner(
+            Copy(), model_inputs, func_name=inspect.stack()[0].function[5:]
+        )
+
     def test_mps_backend_floor(self):
         class Floor(torch.nn.Module):
             def forward(self, x):
File 5 of 5:
@@ -26,10 +26,7 @@

 # Config for Capturing the weights, will be moved in the future

-# TODO(T182928844): Delegate dim order op to backend.
-_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
-    _check_ir_validity=False, _skip_dim_order=True
-)
+_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(_check_ir_validity=False)


 class ansi_colors:
@@ -219,7 +216,6 @@ def lower_module_and_test_output(
             dynamic_shapes=dynamic_shapes,
             edge_compile_config=EdgeCompileConfig(
                 _check_ir_validity=False,
-                _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
             ),
         )

@@ -253,7 +249,6 @@ def lower_module_and_test_output(
                 ),
                 compile_config=exir.EdgeCompileConfig(
                     _check_ir_validity=False,
-                    _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
                 ),
             ).to_executorch(
                 config=ExecutorchBackendConfig(extract_delegate_segments=False)
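
With _skip_dim_order gone, dim-order ops now reach the delegation flow, so a quick graph scan can be handy when debugging. The helper below is hypothetical and not part of this PR; it only walks the FX graph of an exported program and reports any call_function node whose target still mentions dim_order_ops.

from typing import List

from torch.export import ExportedProgram


def find_dim_order_ops(program: ExportedProgram) -> List[str]:
    """Hypothetical debug helper: list dim-order ops remaining in a graph."""
    hits: List[str] = []
    for node in program.graph_module.graph.nodes:
        if node.op == "call_function" and "dim_order_ops" in str(node.target):
            hits.append(str(node.target))
    return hits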