Skip to content

Commit fd69611

Browse files
committed
Address review feedback on group offloading
1 parent 581f051 commit fd69611

File tree

3 files changed

+36
-29
lines changed

3 files changed

+36
-29
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def _pinned_memory_tensors(self):
161161
finally:
162162
pinned_dict = None
163163

164-
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream=None):
164+
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
165165
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
166166
if self.record_stream:
167167
tensor.data.record_stream(default_stream)
@@ -295,7 +295,11 @@ def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None
295295
self.config = config
296296

297297
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
298-
# For disk offload we materialize the safetensor files upfront so callers can inspect them immediately.
298+
# Disk offload only: materialize the safetensor files up front so they exist right after enable_group_offload.
299+
# Needed for flows/tests that inspect the offload dir before the first forward
300+
# eg: model.enable_group_offload(..., offload_to_disk_path=tmpdir)
301+
# assert glob.glob(f"{tmpdir}/*.safetensors")
302+
# In-memory offload stays lazy to allow adapter loading before the first forward.
299303
if self.group.offload_to_disk_path is not None and self.group.offload_leader == module:
300304
self.group.offload_()
301305
return module
@@ -305,18 +309,18 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
305309
# method is the onload_leader of the group.
306310
if self.group.onload_leader is None:
307311
self.group.onload_leader = module
308-
is_leader = self.group.onload_leader == module
312+
309313
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
310-
should_orchestrate = self.group.pinned or is_leader
311314

312-
if should_orchestrate:
313-
# Pinned groups keep their params on the onload device; orchestrate onload/prefetch/sync every call.
315+
if self.group.onload_leader == module:
316+
# If the current module is the onload_leader of the group, we onload the group if it is supposed
317+
# to onload itself. In the case of using prefetching with streams, we onload the next group if
318+
# it is not supposed to onload itself.
314319
if self.group.pinned:
315-
if is_leader and not self._is_group_on_device():
316-
self.group.onload_()
317-
else:
318-
if is_leader and self.group.onload_self:
320+
if not self._is_group_on_device():
319321
self.group.onload_()
322+
elif self.group.onload_self:
323+
self.group.onload_()
320324

321325
if should_onload_next_group:
322326
self.next_group.onload_()
@@ -335,18 +339,10 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
335339
self.group.stream.synchronize()
336340

337341
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
338-
kwargs = self._send_kwargs_to_device(kwargs)
339-
return args, kwargs
340-
341-
def post_forward(self, module: torch.nn.Module, output):
342-
if self.group.pinned:
343-
return output
344-
345-
if self.group.offload_leader == module:
346-
self.group.offload_()
347-
return output
348342

349-
def _send_kwargs_to_device(self, kwargs):
343+
# Some Autoencoder models use a feature cache that is passed through submodules and modified in place.
344+
# The `send_to_device` call returns a copy of this feature cache object which breaks the inplace updates.
345+
# Use `exclude_kwargs` to mark these cache features so they are not moved.
350346
exclude_kwargs = self.config.exclude_kwargs or []
351347
if exclude_kwargs:
352348
moved_kwargs = send_to_device(
@@ -355,8 +351,19 @@ def _send_kwargs_to_device(self, kwargs):
355351
non_blocking=self.group.non_blocking,
356352
)
357353
kwargs.update(moved_kwargs)
358-
return kwargs
359-
return send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
354+
else:
355+
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
356+
357+
return args, kwargs
358+
359+
def post_forward(self, module: torch.nn.Module, output):
360+
# Pinned groups stay resident; otherwise, offload when the offload leader finishes.
361+
if self.group.pinned:
362+
return output
363+
364+
if self.group.offload_leader == module:
365+
self.group.offload_()
366+
return output
360367

361368
def _is_group_on_device(self) -> bool:
362369
tensors = []
@@ -535,6 +542,10 @@ def _validate_pin_groups(pin_groups: Optional[Union[str, Callable]]) -> Optional
535542
return pin_groups
536543
if isinstance(pin_groups, str) and pin_groups in VALID_PIN_GROUPS:
537544
return pin_groups
545+
elif isinstance(pin_groups, str) and pin_groups not in VALID_PIN_GROUPS:
546+
raise ValueError(
547+
f"`pin_groups` must be None, {', '.join(repr(v) for v in sorted(VALID_PIN_GROUPS))}, or a callable."
548+
)
538549
raise ValueError(
539550
f"`pin_groups` must be None, {', '.join(repr(v) for v in sorted(VALID_PIN_GROUPS))}, or a callable."
540551
)

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -962,11 +962,12 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
962962
"""
963963

964964
_supports_gradient_checkpointing = False
965+
# Group offloading treats the top-level latent bridge and encode/decode stages as natural block boundaries.
966+
# These modules encapsulate most parameters and map cleanly to the model's major subcomponents.
965967
_group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]
966968
# keys to ignore when AlignDeviceHook moves inputs/outputs between devices
967969
# these are shared mutable state modified in-place
968970
_skip_keys = ["feat_cache", "feat_idx"]
969-
_group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]
970971

971972
@register_to_config
972973
def __init__(

src/diffusers/models/modeling_utils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -557,11 +557,6 @@ def enable_group_offload(
557557
... use_stream=True,
558558
... )
559559
```
560-
561-
Args:
562-
pin_groups (`"first_last"` | `"all"` | `Callable`, *optional*):
563-
Optionally keep selected groups on the onload device permanently. See
564-
[`~hooks.group_offloading.apply_group_offloading`] for details.
565560
"""
566561
from ..hooks import apply_group_offloading
567562

0 commit comments

Comments
 (0)