Skip to content

Commit 8029cd7

Browse files
committed
add test and clarify.
1 parent 4e4842f commit 8029cd7

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ def offload_(self):
231231
# The group is now considered offloaded to disk for the rest of the session.
232232
self._is_offloaded_to_disk = True
233233

234+
# We do this to free up the RAM which is still holding the up tensor data.
234235
for tensor_obj in self.tensor_to_key.keys():
235236
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
236237
return

tests/models/test_modeling_common.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import copy
1717
import gc
18+
import glob
1819
import inspect
1920
import json
2021
import os
@@ -1608,6 +1609,35 @@ def test_group_offloading_with_layerwise_casting(self, record_stream, offload_ty
16081609
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
16091610
_ = model(**inputs_dict)[0]
16101611

1612+
@parameterized.expand([(False, "block_level"), (True, "leaf_level")])
1613+
@require_torch_accelerator
1614+
@torch.no_grad()
1615+
def test_group_offloading_with_disk(self, record_stream, offload_type):
1616+
torch.manual_seed(0)
1617+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1618+
model = self.model_class(**init_dict)
1619+
1620+
if not getattr(model, "_supports_group_offloading", True):
1621+
return
1622+
1623+
torch.manual_seed(0)
1624+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1625+
model = self.model_class(**init_dict)
1626+
model.eval()
1627+
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
1628+
with tempfile.TemporaryDirectory() as tmpdir:
1629+
model.enable_group_offload(
1630+
torch_device,
1631+
offload_type=offload_type,
1632+
offload_to_disk_path=tmpdir,
1633+
use_stream=True,
1634+
record_stream=record_stream,
1635+
**additional_kwargs,
1636+
)
1637+
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
1638+
assert has_safetensors
1639+
_ = model(**inputs_dict)[0]
1640+
16111641
def test_auto_model(self, expected_max_diff=5e-5):
16121642
if self.forward_requires_fresh_args:
16131643
model = self.model_class(**self.init_dict)

0 commit comments

Comments
 (0)