Skip to content

Commit d7a28ed

Browse files
committed
Update on "Enable aoti for preprocess"
Land and update torch nightly pin after: pytorch/pytorch#137063 Test Plan: With pytorch/pytorch#137063: ``` pytest -c /dev/null -v -n auto examples/models/llama3_2_vision/preprocess/ ``` [ghstack-poisoned]
2 parents 74a9cbd + e5d3722 commit d7a28ed

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

examples/models/llama3_2_vision/preprocess/test_preprocess.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,22 +64,30 @@ def initialize_models(resize_to_max_canvas: bool) -> Dict[str, Any]:
6464
strict=False,
6565
)
6666

67-
# aoti_path = torch._inductor.aot_compile(
68-
# exported_model.module(),
69-
# model.get_example_inputs(),
70-
# )
67+
aoti_path = torch._inductor.aot_compile(
68+
exported_model.module(),
69+
model.get_example_inputs(),
70+
)
7171

7272
edge_program = to_edge(
7373
exported_model, compile_config=EdgeCompileConfig(_check_ir_validity=False)
7474
)
7575
executorch_model = edge_program.to_executorch()
7676

77+
# Re-export as ExecuTorch edits the ExportedProgram.
78+
exported_model = torch.export.export(
79+
model.get_eager_model(),
80+
model.get_example_inputs(),
81+
dynamic_shapes=model.get_dynamic_shapes(),
82+
strict=False,
83+
)
84+
7785
return {
7886
"config": config,
7987
"reference_model": reference_model,
8088
"model": model,
8189
"exported_model": exported_model,
82-
# "aoti_path": aoti_path,
90+
"aoti_path": aoti_path,
8391
"executorch_model": executorch_model,
8492
}
8593

@@ -268,8 +276,10 @@ def run_preprocess(
268276
aoti_path = models["aoti_path"]
269277
aoti_model = torch._export.aot_load(aoti_path, "cpu")
270278
aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
271-
self.assertTrue(torch.allclose(reference_image, aoti_image))
272-
self.assertEqual(reference_ar, aoti_ar.tolist())
279+
assert torch.allclose(reference_image, aoti_image)
280+
assert (
281+
reference_ar == aoti_ar.tolist()
282+
), f"AOTI model: expected {reference_ar} but got {aoti_ar.tolist()}"
273283

274284
# This test setup mirrors the one in torchtune:
275285
# https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py

0 commit comments

Comments
 (0)