Skip to content

Commit f87d5e3

Browse files
committed
add a test for group offloading + compilation.
1 parent 0f91f2f commit f87d5e3

File tree

1 file changed

+33
-5
lines changed

1 file changed

+33
-5
lines changed

tests/models/test_modeling_common.py

Lines changed: 33 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1744,6 +1744,10 @@ def test_push_to_hub_library_name(self):
17441744
delete_repo(self.repo_id, token=TOKEN)
17451745

17461746

1747+
@require_torch_gpu
1748+
@require_torch_2
1749+
@is_torch_compile
1750+
@slow
17471751
class TorchCompileTesterMixin:
17481752
def setUp(self):
17491753
# clean up the VRAM before each test
@@ -1759,12 +1763,7 @@ def tearDown(self):
17591763
gc.collect()
17601764
backend_empty_cache(torch_device)
17611765

1762-
@require_torch_gpu
1763-
@require_torch_2
1764-
@is_torch_compile
1765-
@slow
17661766
def test_torch_compile_recompilation_and_graph_break(self):
1767-
torch.compiler.reset()
17681767
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
17691768

17701769
model = self.model_class(**init_dict).to(torch_device)
@@ -1778,6 +1777,35 @@ def test_torch_compile_recompilation_and_graph_break(self):
17781777
_ = model(**inputs_dict)
17791778
_ = model(**inputs_dict)
17801779

1780+
def test_compilation_with_group_offloading(self):
1781+
torch._dynamo.config.cache_size_limit = 10000
1782+
1783+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1784+
model = self.model_class(**init_dict)
1785+
1786+
if not getattr(model, "_supports_group_offloading", True):
1787+
return
1788+
1789+
model.eval()
1790+
# TODO: Can test for other group offloading kwargs later if needed.
1791+
group_offload_kwargs = {
1792+
"onload_device": "cuda",
1793+
"offload_device": "cpu",
1794+
"offload_type": "block_level",
1795+
"num_blocks_per_group": 1,
1796+
"use_stream": True,
1797+
"non_blocking": True,
1798+
}
1799+
model.enable_group_offload(**group_offload_kwargs)
1800+
model.compile()
1801+
with (
1802+
torch._inductor.utils.fresh_inductor_cache(),
1803+
torch._dynamo.config.patch(error_on_recompile=True),
1804+
torch.no_grad(),
1805+
):
1806+
_ = model(**inputs_dict)
1807+
_ = model(**inputs_dict)
1808+
17811809

17821810
@slow
17831811
@require_torch_2

0 commit comments

Comments (0)