diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 36eb2c1ef488..5e7be62342c3 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -2059,6 +2059,7 @@ def test_torch_compile_recompilation_and_graph_break(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict).to(torch_device) + model.eval() model = torch.compile(model, fullgraph=True) with ( @@ -2076,6 +2077,7 @@ def test_torch_compile_repeated_blocks(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict).to(torch_device) + model.eval() model.compile_repeated_blocks(fullgraph=True) recompile_limit = 1 @@ -2098,7 +2100,6 @@ def test_compile_with_group_offloading(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) - model.eval() # TODO: Can test for other group offloading kwargs later if needed. group_offload_kwargs = { @@ -2111,11 +2112,11 @@ def test_compile_with_group_offloading(self): } model.enable_group_offload(**group_offload_kwargs) model.compile() + with torch.no_grad(): _ = model(**inputs_dict) _ = model(**inputs_dict) - @require_torch_version_greater("2.7.1") def test_compile_on_different_shapes(self): if self.different_shapes_for_compilation is None: pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.") @@ -2123,6 +2124,7 @@ def test_compile_on_different_shapes(self): init_dict, _ = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict).to(torch_device) + model.eval() model = torch.compile(model, fullgraph=True, dynamic=True) for height, width in self.different_shapes_for_compilation: @@ -2130,6 +2132,26 @@ def test_compile_on_different_shapes(self): inputs_dict = self.prepare_dummy_input(height=height, width=width) _ = model(**inputs_dict) + def test_compile_works_with_aot(self): + from torch._inductor.package import load_package + + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict).to(torch_device) + exported_model = torch.export.export(model, args=(), kwargs=inputs_dict) + + with tempfile.TemporaryDirectory() as tmpdir: + package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2") + _ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path) + assert os.path.exists(package_path) + loaded_binary = load_package(package_path, run_single_threaded=True) + + model.forward = loaded_binary + + with torch.no_grad(): + _ = model(**inputs_dict) + _ = model(**inputs_dict) + @slow @require_torch_2