Skip to content

Commit 335dca8

Browse files
committed
Address review feedback for group offload pinning
1 parent 2e8f538 commit 335dca8

File tree

3 files changed

+27
-41
lines changed

3 files changed

+27
-41
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 16 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
from .hooks import HookRegistry, ModelHook
2828

2929

30+
VALID_PIN_GROUPS = {"all", "first_last"}
31+
32+
3033
if is_accelerate_available():
3134
from accelerate.hooks import AlignDevicesHook, CpuOffload
3235
from accelerate.utils import send_to_device
@@ -302,36 +305,19 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
302305
# method is the onload_leader of the group.
303306
if self.group.onload_leader is None:
304307
self.group.onload_leader = module
308+
is_leader = self.group.onload_leader == module
309+
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
310+
should_orchestrate = self.group.pinned or is_leader
311+
312+
if should_orchestrate:
313+
# Pinned groups keep their params on the onload device; orchestrate onload/prefetch/sync every call.
314+
if self.group.pinned:
315+
if is_leader and not self._is_group_on_device():
316+
self.group.onload_()
317+
else:
318+
if is_leader and self.group.onload_self:
319+
self.group.onload_()
305320

306-
if self.group.pinned:
307-
if self.group.onload_leader == module and not self._is_group_on_device():
308-
self.group.onload_()
309-
310-
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
311-
if should_onload_next_group:
312-
self.next_group.onload_()
313-
314-
should_synchronize = (
315-
not self.group.onload_self
316-
and self.group.stream is not None
317-
and not should_onload_next_group
318-
and not self.group.record_stream
319-
)
320-
if should_synchronize:
321-
self.group.stream.synchronize()
322-
323-
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
324-
kwargs = self._send_kwargs_to_device(kwargs)
325-
return args, kwargs
326-
327-
# If the current module is the onload_leader of the group, we onload the group if it is supposed
328-
# to onload itself. In the case of using prefetching with streams, we onload the next group if
329-
# it is not supposed to onload itself.
330-
if self.group.onload_leader == module:
331-
if self.group.onload_self:
332-
self.group.onload_()
333-
334-
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
335321
if should_onload_next_group:
336322
self.next_group.onload_()
337323

@@ -345,9 +331,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
345331
# If this group didn't onload itself, it means it was asynchronously onloaded by the
346332
# previous group. We need to synchronize the side stream to ensure parameters
347333
# are completely loaded to proceed with forward pass. Without this, uninitialized
348-
# weights will be used in the computation, leading to incorrect results
349-
# Also, we should only do this synchronization if we don't already do it from the sync call in
350-
# self.next_group.onload_, hence the `not should_onload_next_group` check.
334+
# weights will be used in the computation, leading to incorrect results.
351335
self.group.stream.synchronize()
352336

353337
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
@@ -546,9 +530,6 @@ def pre_forward(self, module, *args, **kwargs):
546530
return args, kwargs
547531

548532

549-
VALID_PIN_GROUPS = {"all", "first_last"}
550-
551-
552533
def _validate_pin_groups(pin_groups: Optional[Union[str, Callable]]) -> Optional[Union[str, Callable]]:
553534
if pin_groups is None or callable(pin_groups):
554535
return pin_groups
@@ -708,9 +689,6 @@ def apply_group_offloading(
708689

709690

710691
def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
711-
registry = HookRegistry.check_if_exists_or_initialize(module)
712-
registry._group_offload_pin_groups = config.pin_groups
713-
714692
if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
715693
_apply_group_offloading_block_level(module, config)
716694
elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:

src/diffusers/models/modeling_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
252252
_parallel_config = None
253253
_cp_plan = None
254254
_skip_keys = None
255+
_group_offload_block_modules = None
255256

256257
def __init__(self):
257258
super().__init__()
@@ -556,6 +557,11 @@ def enable_group_offload(
556557
... use_stream=True,
557558
... )
558559
```
560+
561+
Args:
562+
pin_groups (`"first_last"` | `"all"` | `Callable`, *optional*):
563+
Optionally keep selected groups on the onload device permanently. See
564+
[`~hooks.group_offloading.apply_group_offloading`] for details.
559565
"""
560566
from ..hooks import apply_group_offloading
561567

tests/hooks/test_group_offloading.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import contextlib
1616
import gc
1717
import unittest
18+
from typing import Any, Iterable, List, Optional, Sequence, Union
1819

1920
import torch
2021
from parameterized import parameterized
@@ -34,7 +35,6 @@
3435
torch_device,
3536
)
3637

37-
from typing import Any, Iterable, List, Optional, Sequence, Union
3838

3939
class DummyBlock(torch.nn.Module):
4040
def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
@@ -217,8 +217,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
217217
x = block(x)
218218
x = self.norm(x)
219219
return x
220-
220+
221221
# Test for https://github.com/huggingface/diffusers/pull/12747
222+
223+
222224
class DummyCallableBySubmodule:
223225
"""
224226
Callable group offloading pinner that pins first and last DummyBlock
@@ -633,7 +635,7 @@ def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=Non
633635
"layers_per_block": 1,
634636
}
635637
return init_dict
636-
638+
637639
def test_block_level_offloading_with_pin_groups_stay_on_device(self):
638640
if torch.device(torch_device).type not in ["cuda", "xpu"]:
639641
return

0 commit comments

Comments (0)