Skip to content

Commit 53659d8

Browse files
committed
Eagerly write disk offload tensors for safetensor checks
1 parent 0cbd079 commit 53659d8

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,9 @@ def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None
292292
self.config = config
293293

294294
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
295+
# For disk offload we materialize the safetensor files upfront so callers can inspect them immediately.
296+
if self.group.offload_to_disk_path is not None and self.group.offload_leader == module:
297+
self.group.offload_()
295298
return module
296299

297300
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
@@ -545,12 +548,15 @@ def pre_forward(self, module, *args, **kwargs):
545548

546549
VALID_PIN_GROUPS = {"all", "first_last"}
547550

551+
548552
def _validate_pin_groups(pin_groups: Optional[Union[str, Callable]]) -> Optional[Union[str, Callable]]:
549553
if pin_groups is None or callable(pin_groups):
550554
return pin_groups
551555
if isinstance(pin_groups, str) and pin_groups in VALID_PIN_GROUPS:
552556
return pin_groups
553-
raise ValueError(f"`pin_groups` must be None, {', '.join(repr(v) for v in sorted(VALID_PIN_GROUPS))}, or a callable.")
557+
raise ValueError(
558+
f"`pin_groups` must be None, {', '.join(repr(v) for v in sorted(VALID_PIN_GROUPS))}, or a callable."
559+
)
554560

555561

556562
def apply_group_offloading(

tests/hooks/test_group_offloading.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -569,9 +569,7 @@ def test_error_raised_if_pin_groups_received_invalid_value(self):
569569
"use_stream": True,
570570
}
571571
model = self.get_model()
572-
with self.assertRaisesRegex(
573-
ValueError, "`pin_groups` must be None, 'all', 'first_last', or a callable."
574-
):
572+
with self.assertRaisesRegex(ValueError, "`pin_groups` must be None, 'all', 'first_last', or a callable."):
575573
model.enable_group_offload(
576574
**default_parameters,
577575
pin_groups="invalid value",

0 commit comments

Comments
 (0)