Skip to content

Commit 64833eb

Browse files
committed
Update on "Enable aoti for preprocess"
Update torch nightly pin to 11/01 after: pytorch/pytorch#137063 Test Plan: With pytorch/pytorch#137063: ``` pytest -c /dev/null -v -n auto examples/models/llama3_2_vision/preprocess/ ``` [ghstack-poisoned]
2 parents 0ef9504 + 0faa31d commit 64833eb

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

examples/models/llama3_2_vision/preprocess/test_preprocess.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def run_preprocess(
245245
image_tensor, inscribed_size, best_resolution
246246
)
247247
eager_ar = eager_ar.tolist()
248-
assert torch.allclose(reference_image, eager_image, rtol=1e-4, atol=1e-4)
248+
assert_expected(eager_image, reference_image, rtol=0, atol=1e-4)
249249
assert (
250250
reference_ar == eager_ar
251251
), f"Eager model: expected {reference_ar} but got {eager_ar}"
@@ -256,7 +256,7 @@ def run_preprocess(
256256
image_tensor, inscribed_size, best_resolution
257257
)
258258
exported_ar = exported_ar.tolist()
259-
assert torch.allclose(reference_image, exported_image, rtol=1e-4, atol=1e-4)
259+
assert_expected(exported_image, reference_image, rtol=0, atol=1e-4)
260260
assert (
261261
reference_ar == exported_ar
262262
), f"Exported model: expected {reference_ar} but got {exported_ar}"
@@ -267,7 +267,7 @@ def run_preprocess(
267267
et_image, et_ar = executorch_module.forward(
268268
(image_tensor, inscribed_size, best_resolution)
269269
)
270-
assert torch.allclose(reference_image, et_image, rtol=1e-4, atol=1e-4)
270+
assert_expected(et_image, reference_image, rtol=0, atol=1e-4)
271271
assert (
272272
reference_ar == et_ar.tolist()
273273
), f"Executorch model: expected {reference_ar} but got {et_ar.tolist()}"
@@ -276,7 +276,7 @@ def run_preprocess(
276276
aoti_path = models["aoti_path"]
277277
aoti_model = torch._export.aot_load(aoti_path, "cpu")
278278
aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
279-
assert torch.allclose(reference_image, aoti_image)
279+
assert_expected(aoti_image, reference_image, rtol=0, atol=1e-4)
280280
assert (
281281
reference_ar == aoti_ar.tolist()
282282
), f"AOTI model: expected {reference_ar} but got {aoti_ar.tolist()}"

0 commit comments

Comments
 (0)