Skip to content

Commit 8cd0784

Browse files
committed
[mps] Disable dim_order for mps_examples
This is a temporary fix to get CI green.
1 parent fb1cc93 commit 8cd0784

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

examples/apple/mps/scripts/mps_example.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,12 +165,16 @@ def get_model_config(args):
165165
inputs_copy = tuple(inputs_copy)
166166

167167
# pre-autograd export. eventually this will become torch.export
168+
# TODO: revert _skip_dim_order=True once the mps issue is fixed.
169+
skip_dim_order = True
168170
with torch.no_grad():
169171
model = torch.export.export_for_training(model, example_inputs).module()
170172
edge: EdgeProgramManager = export_to_edge(
171173
model,
172174
example_inputs,
173-
edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
175+
edge_compile_config=EdgeCompileConfig(
176+
_check_ir_validity=False, _skip_dim_order=skip_dim_order
177+
),
174178
)
175179

176180
edge_program_manager_copy = copy.deepcopy(edge)
@@ -192,7 +196,9 @@ def get_model_config(args):
192196
executorch_program: ExecutorchProgramManager = export_to_edge(
193197
lowered_module,
194198
example_inputs,
195-
edge_compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
199+
edge_compile_config=exir.EdgeCompileConfig(
200+
_check_ir_validity=False, _skip_dim_order=skip_dim_order
201+
),
196202
).to_executorch(config=ExecutorchBackendConfig(extract_delegate_segments=False))
197203

198204
dtype = "float16" if args.use_fp16 else "float32"

0 commit comments

Comments
 (0)