|
15 | 15 |
|
16 | 16 | import copy |
17 | 17 | import gc |
| 18 | +import glob |
18 | 19 | import inspect |
19 | 20 | import json |
20 | 21 | import os |
@@ -1608,6 +1609,35 @@ def test_group_offloading_with_layerwise_casting(self, record_stream, offload_ty |
1608 | 1609 | model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) |
1609 | 1610 | _ = model(**inputs_dict)[0] |
1610 | 1611 |
|
| 1612 | + @parameterized.expand([(False, "block_level"), (True, "leaf_level")]) |
| 1613 | + @require_torch_accelerator |
| 1614 | + @torch.no_grad() |
| 1615 | + def test_group_offloading_with_disk(self, record_stream, offload_type): |
| 1616 | + torch.manual_seed(0) |
| 1617 | + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
| 1618 | + model = self.model_class(**init_dict) |
| 1619 | + |
| 1620 | + if not getattr(model, "_supports_group_offloading", True): |
| 1621 | + return |
| 1622 | + |
| 1623 | + torch.manual_seed(0) |
| 1624 | + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
| 1625 | + model = self.model_class(**init_dict) |
| 1626 | + model.eval() |
| 1627 | + additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1} |
| 1628 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 1629 | + model.enable_group_offload( |
| 1630 | + torch_device, |
| 1631 | + offload_type=offload_type, |
| 1632 | + offload_to_disk_path=tmpdir, |
| 1633 | + use_stream=True, |
| 1634 | + record_stream=record_stream, |
| 1635 | + **additional_kwargs, |
| 1636 | + ) |
| 1637 | + has_safetensors = glob.glob(f"{tmpdir}/*.safetensors") |
| 1638 | + assert has_safetensors |
| 1639 | + _ = model(**inputs_dict)[0] |
| 1640 | + |
1611 | 1641 | def test_auto_model(self, expected_max_diff=5e-5): |
1612 | 1642 | if self.forward_requires_fresh_args: |
1613 | 1643 | model = self.model_class(**self.init_dict) |
|
0 commit comments