Skip to content

Commit 0007969

Browse files
authored
Merge branch 'main' into metadata-lora
2 parents 29ff6f1 + 16c955c commit 0007969

File tree

1 file changed

+29
-5
lines changed

1 file changed

+29
-5
lines changed

tests/models/test_modeling_common.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1829,6 +1829,10 @@ def test_push_to_hub_library_name(self):
18291829
delete_repo(self.repo_id, token=TOKEN)
18301830

18311831

1832+
@require_torch_gpu
1833+
@require_torch_2
1834+
@is_torch_compile
1835+
@slow
18321836
class TorchCompileTesterMixin:
18331837
def setUp(self):
18341838
# clean up the VRAM before each test
@@ -1844,12 +1848,7 @@ def tearDown(self):
18441848
gc.collect()
18451849
backend_empty_cache(torch_device)
18461850

1847-
@require_torch_gpu
1848-
@require_torch_2
1849-
@is_torch_compile
1850-
@slow
18511851
def test_torch_compile_recompilation_and_graph_break(self):
1852-
torch.compiler.reset()
18531852
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
18541853

18551854
model = self.model_class(**init_dict).to(torch_device)
@@ -1863,6 +1862,31 @@ def test_torch_compile_recompilation_and_graph_break(self):
18631862
_ = model(**inputs_dict)
18641863
_ = model(**inputs_dict)
18651864

1865+
def test_compile_with_group_offloading(self):
    """Smoke-test a compiled model with block-level group offloading enabled.

    Builds the model from the common init args, enables group offloading
    (one block per group, stream-based, non-blocking transfers), compiles
    the model and runs two forward passes under ``torch.no_grad()``. The
    outputs are discarded; the test passes as long as nothing raises.

    Raises:
        unittest.SkipTest: If the model sets ``_supports_group_offloading``
            to a falsy value.
    """
    # Function-scope import keeps this method self-contained. SkipTest is
    # honoured by both the unittest and the pytest runners.
    import unittest

    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    model = self.model_class(**init_dict)

    # Report unsupported models as skipped instead of silently passing, and
    # do so before any global compiler configuration is modified.
    if not getattr(model, "_supports_group_offloading", True):
        raise unittest.SkipTest("Model does not support group offloading.")

    model.eval()
    # TODO: Can test for other group offloading kwargs later if needed.
    group_offload_kwargs = {
        # Use the suite-wide device, consistent with the other tests here,
        # rather than a hard-coded "cuda".
        "onload_device": torch_device,
        "offload_device": "cpu",
        "offload_type": "block_level",
        "num_blocks_per_group": 1,
        "use_stream": True,
        "non_blocking": True,
    }

    # Raise the dynamo cache limit for this test only and restore it after,
    # so the setting does not leak into tests that run later in the process.
    # NOTE(review): presumably raised because offloading triggers extra
    # recompilations -- confirm the value is still required.
    previous_cache_size_limit = torch._dynamo.config.cache_size_limit
    torch._dynamo.config.cache_size_limit = 10000
    try:
        model.enable_group_offload(**group_offload_kwargs)
        model.compile()
        with torch.no_grad():
            # Two calls: the first triggers compilation, the second runs
            # the model again after it has been compiled.
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)
    finally:
        torch._dynamo.config.cache_size_limit = previous_cache_size_limit
1889+
18661890

18671891
@slow
18681892
@require_torch_2

0 commit comments

Comments
 (0)