From f87d5e34c8b68c2a908297b55bf94ca1b4ac9bd8 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 6 Jun 2025 09:24:46 +0530
Subject: [PATCH 1/2] add a test for group offloading + compilation.

---
 tests/models/test_modeling_common.py | 38 ++++++++++++++++++++++++----
 1 file changed, 33 insertions(+), 5 deletions(-)

diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 453690c1c901..9291b2df5171 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -1744,6 +1744,10 @@ def test_push_to_hub_library_name(self):
         delete_repo(self.repo_id, token=TOKEN)
 
 
+@require_torch_gpu
+@require_torch_2
+@is_torch_compile
+@slow
 class TorchCompileTesterMixin:
     def setUp(self):
         # clean up the VRAM before each test
@@ -1759,12 +1763,7 @@ def tearDown(self):
         gc.collect()
         backend_empty_cache(torch_device)
 
-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
     def test_torch_compile_recompilation_and_graph_break(self):
-        torch.compiler.reset()
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict).to(torch_device)
 
@@ -1778,6 +1777,35 @@ def test_torch_compile_recompilation_and_graph_break(self):
             _ = model(**inputs_dict)
             _ = model(**inputs_dict)
 
+    def test_compilation_with_group_offloading(self):
+        torch._dynamo.config.cache_size_limit = 10000
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        if not getattr(model, "_supports_group_offloading", True):
+            return
+
+        model.eval()
+        # TODO: Can test for other group offloading kwargs later if needed.
+        group_offload_kwargs = {
+            "onload_device": "cuda",
+            "offload_device": "cpu",
+            "offload_type": "block_level",
+            "num_blocks_per_group": 1,
+            "use_stream": True,
+            "non_blocking": True,
+        }
+        model.enable_group_offload(**group_offload_kwargs)
+        model.compile()
+        with (
+            torch._inductor.utils.fresh_inductor_cache(),
+            torch._dynamo.config.patch(error_on_recompile=True),
+            torch.no_grad(),
+        ):
+            _ = model(**inputs_dict)
+            _ = model(**inputs_dict)
+
 
 @slow
 @require_torch_2

From ae9d217e690d56b06fa3ec7ee77e27aaed12aa2c Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 6 Jun 2025 10:32:32 +0530
Subject: [PATCH 2/2] tests

---
 tests/models/test_modeling_common.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 9291b2df5171..5087bd0094a5 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -1777,7 +1777,7 @@ def test_torch_compile_recompilation_and_graph_break(self):
             _ = model(**inputs_dict)
             _ = model(**inputs_dict)
 
-    def test_compilation_with_group_offloading(self):
+    def test_compile_with_group_offloading(self):
         torch._dynamo.config.cache_size_limit = 10000
 
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1798,11 +1798,7 @@ def test_compile_with_group_offloading(self):
         }
         model.enable_group_offload(**group_offload_kwargs)
         model.compile()
-        with (
-            torch._inductor.utils.fresh_inductor_cache(),
-            torch._dynamo.config.patch(error_on_recompile=True),
-            torch.no_grad(),
-        ):
+        with torch.no_grad():
             _ = model(**inputs_dict)
             _ = model(**inputs_dict)
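
Not part of the patch: below is a minimal, hedged sketch of the user-facing pattern the new test exercises — enabling group offloading on a diffusers model and then compiling it before inference. The choice of `AutoencoderKL` with its default config, the input shape, and the dtype are illustrative assumptions, not something the patch prescribes; the `enable_group_offload` kwargs mirror the ones used in the test. It assumes a CUDA GPU and PyTorch 2.x.

```python
# Sketch only: group offloading + torch.compile, using AutoencoderKL as an
# arbitrary stand-in model (an assumption for illustration).
import torch
from diffusers import AutoencoderKL

assert torch.cuda.is_available(), "use_stream=True needs a CUDA device"

model = AutoencoderKL()  # default config; any model supporting group offloading works
model.eval()

# Keep parameters on CPU and stream each block onto the GPU right before it runs.
# These kwargs are the same ones the new test passes.
model.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
    non_blocking=True,
)

# Compile in place; subsequent forward passes go through torch.compile.
model.compile()

with torch.no_grad():
    # Input lives on the onload device; shape chosen for the default config.
    sample = torch.randn(1, 3, 32, 32, device="cuda")
    _ = model(sample)  # first call triggers compilation
    _ = model(sample)
```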