@@ -76,22 +76,30 @@ def initialize_models(resize_to_max_canvas: bool) -> Dict[str, Any]:
7676 strict = False ,
7777 )
7878
79- # aoti_path = torch._inductor.aot_compile(
80- # exported_model.module(),
81- # model.get_example_inputs(),
82- # )
79+ aoti_path = torch ._inductor .aot_compile (
80+ exported_model .module (),
81+ model .get_example_inputs (),
82+ )
8383
8484 edge_program = to_edge (
8585 exported_model , compile_config = EdgeCompileConfig (_check_ir_validity = False )
8686 )
8787 executorch_model = edge_program .to_executorch ()
8888
89+ # Re-export, as lowering to executorch changes the graph.
90+ exported_model = torch .export .export (
91+ model .get_eager_model (),
92+ model .get_example_inputs (),
93+ dynamic_shapes = model .get_dynamic_shapes (),
94+ strict = False ,
95+ )
96+
8997 return {
9098 "config" : config ,
9199 "reference_model" : reference_model ,
92100 "model" : model ,
93101 "exported_model" : exported_model ,
94- # "aoti_path": aoti_path,
102+ "aoti_path" : aoti_path ,
95103 "executorch_model" : executorch_model ,
96104 }
97105
@@ -237,11 +245,11 @@ def run_preprocess(
237245 self .assertEqual (reference_ar , et_ar .tolist ())
238246
239247 # Run aoti model and check it matches reference model.
240- # aoti_path = models["aoti_path"]
241- # aoti_model = torch._export.aot_load(aoti_path, "cpu")
242- # aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
243- # self.assertTrue(torch.allclose(reference_image, aoti_image))
244- # self.assertEqual(reference_ar, aoti_ar.tolist())
248+ aoti_path = models ["aoti_path" ]
249+ aoti_model = torch ._export .aot_load (aoti_path , "cpu" )
250+ aoti_image , aoti_ar = aoti_model (image_tensor , inscribed_size , best_resolution )
251+ self .assertTrue (torch .allclose (reference_image , aoti_image ))
252+ self .assertEqual (reference_ar , aoti_ar .tolist ())
245253
246254 # This test setup mirrors the one in torchtune:
247255 # https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py
0 commit comments