1 file changed: examples/apple/mps/scripts (+8, -2 lines)

@@ -165,12 +165,16 @@ def get_model_config(args):
         inputs_copy = tuple(inputs_copy)
 
     # pre-autograd export. eventually this will become torch.export
+    # TODO: revert _skip_dim_order=True once the mps issue is fixed.
+    skip_dim_order = True
     with torch.no_grad():
         model = torch.export.export_for_training(model, example_inputs).module()
         edge: EdgeProgramManager = export_to_edge(
             model,
             example_inputs,
-            edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
+            edge_compile_config=EdgeCompileConfig(
+                _check_ir_validity=False, _skip_dim_order=skip_dim_order
+            ),
         )
 
     edge_program_manager_copy = copy.deepcopy(edge)
@@ -192,7 +196,9 @@ def get_model_config(args):
     executorch_program: ExecutorchProgramManager = export_to_edge(
         lowered_module,
         example_inputs,
-        edge_compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
+        edge_compile_config=exir.EdgeCompileConfig(
+            _check_ir_validity=False, _skip_dim_order=skip_dim_order
+        ),
     ).to_executorch(config=ExecutorchBackendConfig(extract_delegate_segments=False))
 
     dtype = "float16" if args.use_fp16 else "float32"
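
For readers who want to try the same configuration outside of the MPS example script, the sketch below passes the _skip_dim_order flag through the public executorch.exir.to_edge API. The tiny model, input shape, and the two-step export are illustrative assumptions for this sketch; only the EdgeCompileConfig flags come from this diff.

# Minimal sketch (assumed model and inputs, not part of the diff): lower a
# small module to Edge IR with the same EdgeCompileConfig flags as above.
import torch
from executorch.exir import EdgeCompileConfig, to_edge


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1.0)


example_inputs = (torch.randn(1, 4),)

# Mirrors the TODO in the diff: revert to the default once the MPS issue is fixed.
skip_dim_order = True

with torch.no_grad():
    # Pre-autograd export, then the ATen-dialect export that export_to_edge performs internally.
    pre_autograd_module = torch.export.export_for_training(TinyModel(), example_inputs).module()
    aten_program = torch.export.export(pre_autograd_module, example_inputs)

edge = to_edge(
    aten_program,
    compile_config=EdgeCompileConfig(
        _check_ir_validity=False,
        _skip_dim_order=skip_dim_order,
    ),
)
executorch_program = edge.to_executorch()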