Commit 6f5887e

Address review feedback for group offload pinning
1 parent 33d8b52 commit 6f5887e

File tree

2 files changed (+74, -35 lines)

src/diffusers/hooks/group_offloading.py

Lines changed: 65 additions & 32 deletions
@@ -60,6 +60,8 @@ class GroupOffloadingConfig:
     offload_to_disk_path: Optional[str] = None
     stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
     block_modules: Optional[List[str]] = None
+    exclude_kwargs: Optional[List[str]] = None
+    module_prefix: Optional[str] = ""
     pin_groups: Optional[Union[str, Callable]] = None
 
 
@@ -156,27 +158,27 @@ def _pinned_memory_tensors(self):
         finally:
             pinned_dict = None
 
-    def _transfer_tensor_to_device(self, tensor, source_tensor):
+    def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream=None):
         tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
         if self.record_stream:
-            tensor.data.record_stream(self._torch_accelerator_module.current_stream())
+            tensor.data.record_stream(default_stream)
 
-    def _process_tensors_from_modules(self, pinned_memory=None):
+    def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
         for group_module in self.modules:
             for param in group_module.parameters():
                 source = pinned_memory[param] if pinned_memory else param.data
-                self._transfer_tensor_to_device(param, source)
+                self._transfer_tensor_to_device(param, source, default_stream)
             for buffer in group_module.buffers():
                 source = pinned_memory[buffer] if pinned_memory else buffer.data
-                self._transfer_tensor_to_device(buffer, source)
+                self._transfer_tensor_to_device(buffer, source, default_stream)
 
         for param in self.parameters:
             source = pinned_memory[param] if pinned_memory else param.data
-            self._transfer_tensor_to_device(param, source)
+            self._transfer_tensor_to_device(param, source, default_stream)
 
         for buffer in self.buffers:
             source = pinned_memory[buffer] if pinned_memory else buffer.data
-            self._transfer_tensor_to_device(buffer, source)
+            self._transfer_tensor_to_device(buffer, source, default_stream)
 
     def _onload_from_disk(self):
         if self.stream is not None:
@@ -211,10 +213,11 @@ def _onload_from_memory(self):
             self.stream.synchronize()
 
         context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        default_stream = self._torch_accelerator_module.current_stream() if self.stream is not None else None
         with context:
             if self.stream is not None:
                 with self._pinned_memory_tensors() as pinned_memory:
-                    self._process_tensors_from_modules(pinned_memory)
+                    self._process_tensors_from_modules(pinned_memory, default_stream=default_stream)
             else:
                 self._process_tensors_from_modules(None)
 
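For context on the `default_stream` change above, here is a minimal sketch (not part of this commit, assuming a CUDA device) of the pattern now used: the consuming (default) stream is captured before entering the transfer stream, and the freshly copied tensor is recorded against it so the caching allocator does not reuse its memory while the default stream may still read it.

```python
import torch

transfer_stream = torch.cuda.Stream()
default_stream = torch.cuda.current_stream()  # captured outside the transfer-stream context

pinned = torch.randn(4, 4).pin_memory()
with torch.cuda.stream(transfer_stream):
    # The copy runs on the transfer stream ...
    on_device = pinned.to("cuda", non_blocking=True)
    # ... but the tensor is recorded on the stream that will later consume it.
    on_device.record_stream(default_stream)
```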
@@ -308,13 +311,16 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
                     self.next_group.onload_()
 
                 should_synchronize = (
-                    not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
+                    not self.group.onload_self
+                    and self.group.stream is not None
+                    and not should_onload_next_group
+                    and not self.group.record_stream
                 )
                 if should_synchronize:
                     self.group.stream.synchronize()
 
                 args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
-                kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
+                kwargs = self._send_kwargs_to_device(kwargs)
                 return args, kwargs
 
         # If the current module is the onload_leader of the group, we onload the group if it is supposed
@@ -329,7 +335,10 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
                 self.next_group.onload_()
 
             should_synchronize = (
-                not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
+                not self.group.onload_self
+                and self.group.stream is not None
+                and not should_onload_next_group
+                and not self.group.record_stream
             )
             if should_synchronize:
                 # If this group didn't onload itself, it means it was asynchronously onloaded by the
@@ -341,7 +350,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
                 self.group.stream.synchronize()
 
         args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
-        kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
+        kwargs = self._send_kwargs_to_device(kwargs)
         return args, kwargs
 
     def post_forward(self, module: torch.nn.Module, output):
@@ -360,10 +369,19 @@ def _is_group_on_device(self) -> bool:
         tensors.extend(self.group.parameters)
         tensors.extend(self.group.buffers)
 
-        if len(tensors) == 0:
-            return True
+        return len(tensors) > 0 and all(t.device == self.group.onload_device for t in tensors)
 
-        return all(t.device == self.group.onload_device for t in tensors)
+    def _send_kwargs_to_device(self, kwargs):
+        exclude_kwargs = self.config.exclude_kwargs or []
+        if exclude_kwargs:
+            moved_kwargs = send_to_device(
+                {k: v for k, v in kwargs.items() if k not in exclude_kwargs},
+                self.group.onload_device,
+                non_blocking=self.group.non_blocking,
+            )
+            kwargs.update(moved_kwargs)
+            return kwargs
+        return send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
 
 
 class LazyPrefetchGroupOffloadingHook(ModelHook):
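A small standalone sketch (not part of this commit) of what `_send_kwargs_to_device` achieves when `exclude_kwargs` is set: excluded keys bypass `send_to_device`, so a mutable kwarg such as a cache list keeps its object identity. The key name `cache` is only illustrative.

```python
import torch
from accelerate.utils import send_to_device

exclude_kwargs = ["cache"]
cache = []  # mutable state that must keep its identity across forward passes
kwargs = {"hidden_states": torch.randn(2, 8), "cache": cache}

# Move only the non-excluded entries, then merge them back in place.
moved = send_to_device({k: v for k, v in kwargs.items() if k not in exclude_kwargs}, "cpu")
kwargs.update(moved)
assert kwargs["cache"] is cache  # the excluded entry was never copied
```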
@@ -524,6 +542,17 @@ def pre_forward(self, module, *args, **kwargs):
         return args, kwargs
 
 
+def _normalize_pin_groups(pin_groups: Optional[Union[str, Callable]]) -> Optional[Union[str, Callable]]:
+    if isinstance(pin_groups, str):
+        normalized_pin_groups = pin_groups.lower()
+        if normalized_pin_groups not in {"first_last", "all"}:
+            raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")
+        return normalized_pin_groups
+    if pin_groups is not None and not callable(pin_groups):
+        raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")
+    return pin_groups
+
+
 def apply_group_offloading(
     module: torch.nn.Module,
     onload_device: Union[str, torch.device],
@@ -536,6 +565,7 @@ def apply_group_offloading(
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
     block_modules: Optional[List[str]] = None,
+    exclude_kwargs: Optional[List[str]] = None,
     pin_groups: Optional[Union[str, Callable]] = None,
 ) -> None:
     r"""
@@ -597,6 +627,10 @@ def apply_group_offloading(
         block_modules (`List[str]`, *optional*):
            List of module names that should be treated as blocks for offloading. If provided, only these modules
            will be considered for block-level offloading. If not provided, the default block detection logic will be used.
+        exclude_kwargs (`List[str]`, *optional*):
+            List of kwarg keys that should not be processed by `send_to_device`. This is useful for mutable state like
+            caching lists that need to maintain their object identity across forward passes. If not provided, will be
+            inferred from the module's `_skip_keys` attribute if it exists.
         pin_groups (`"first_last"` or `"all"` or `Callable`, *optional*, defaults to `None`):
             Optionally keeps selected groups on the onload device permanently. Use `"first_last"` to pin the first
             and last parameter-bearing groups, `"all"` to pin every parameter-bearing group, or pass a callable that
@@ -640,17 +674,14 @@ def apply_group_offloading(
     if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
         raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")
 
-    normalized_pin_groups = pin_groups
-    if isinstance(pin_groups, str):
-        normalized_pin_groups = pin_groups.lower()
-        if normalized_pin_groups not in {"first_last", "all"}:
-            raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")
-    elif pin_groups is not None and not callable(pin_groups):
-        raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")
+    pin_groups = _normalize_pin_groups(pin_groups)
+    _raise_error_if_accelerate_model_or_sequential_hook_present(module)
 
-    pin_groups = normalized_pin_groups
+    if block_modules is None:
+        block_modules = getattr(module, "_group_offload_block_modules", None)
 
-    _raise_error_if_accelerate_model_or_sequential_hook_present(module)
+    if exclude_kwargs is None:
+        exclude_kwargs = getattr(module, "_skip_keys", None)
 
     config = GroupOffloadingConfig(
         onload_device=onload_device,
@@ -663,6 +694,8 @@ def apply_group_offloading(
         low_cpu_mem_usage=low_cpu_mem_usage,
         offload_to_disk_path=offload_to_disk_path,
         block_modules=block_modules,
+        exclude_kwargs=exclude_kwargs,
+        module_prefix="",
         pin_groups=pin_groups,
     )
     _apply_group_offloading(module, config)
@@ -701,7 +734,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
 
     for name, submodule in module.named_children():
         # Check if this is an explicitly defined block module
-        if name in block_modules:
+        if block_modules and name in block_modules:
             # Apply block offloading to the specified submodule
             _apply_block_offloading_to_submodule(
                 submodule, name, config, modules_with_group_offloading, matched_module_groups
@@ -713,7 +746,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
             if len(current_modules) == 0:
                 continue
 
-            group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
+            group_id = f"{config.module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
             group = ModuleGroup(
                 modules=current_modules,
                 offload_device=config.offload_device,
@@ -766,7 +799,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
         stream=None,
         record_stream=False,
         onload_self=True,
-        group_id=f"{module.__class__.__name__}_unmatched_group",
+        group_id=f"{config.module_prefix}{module.__class__.__name__}_unmatched_group",
     )
     if config.stream is None:
         _apply_group_offloading_hook(module, unmatched_group, config=config)
@@ -797,7 +830,7 @@ def _apply_block_offloading_to_submodule(
         if len(current_modules) == 0:
             continue
 
-        group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
+        group_id = f"{config.module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
         group = ModuleGroup(
             modules=current_modules,
             offload_device=config.offload_device,
@@ -829,7 +862,7 @@ def _apply_block_offloading_to_submodule(
             record_stream=config.record_stream,
             low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
-            group_id=name,
+            group_id=f"{config.module_prefix}{name}",
         )
         matched_module_groups.append(group)
         modules_with_group_offloading.add(name)
@@ -859,7 +892,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
             record_stream=config.record_stream,
             low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
-            group_id=name,
+            group_id=f"{config.module_prefix}{name}",
         )
         _apply_group_offloading_hook(submodule, group, config=config)
         modules_with_group_offloading.add(name)
@@ -906,7 +939,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
             record_stream=config.record_stream,
             low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
-            group_id=name,
+            group_id=f"{config.module_prefix}{name}",
        )
        _apply_group_offloading_hook(parent_module, group, config=config)
 
@@ -962,7 +995,7 @@ def _apply_lazy_group_offloading_hook(
     hook = GroupOffloadingHook(group, config=config)
     registry.register_hook(hook, _GROUP_OFFLOADING)
 
-    lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook(pin_groups = config.pin_groups)
+    lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook(pin_groups=config.pin_groups)
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
 
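Taken together, the changes to `apply_group_offloading` might be exercised as in the sketch below (not part of this commit; the toy module, the excluded kwarg name `cache`, and the availability of a CUDA device are assumptions, and `pin_groups` also accepts a callable as described in the docstring above).

```python
import torch
from diffusers.hooks.group_offloading import apply_group_offloading


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A ModuleList of blocks so block-level offloading has something to group.
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(4)])

    def forward(self, x, cache=None):
        for block in self.blocks:
            x = block(x)
        return x


model = ToyModel()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    exclude_kwargs=["cache"],   # left untouched by send_to_device
    pin_groups="first_last",    # or "all", or a callable selecting groups to pin
)
```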
src/diffusers/models/modeling_utils.py

Lines changed: 9 additions & 3 deletions
@@ -531,7 +531,9 @@ def enable_group_offload(
         record_stream: bool = False,
         low_cpu_mem_usage=False,
         offload_to_disk_path: Optional[str] = None,
-        pin_groups: Optional[Union[str, Callable]] = None
+        block_modules: Optional[str] = None,
+        exclude_kwargs: Optional[str] = None,
+        pin_groups: Optional[Union[str, Callable]] = None,
     ) -> None:
         r"""
         Activates group offloading for the current model.
@@ -571,7 +573,10 @@ def enable_group_offload(
                 f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
                 f"open an issue at https://github.com/huggingface/diffusers/issues."
             )
-        block_modules = getattr(self, "_group_offload_block_modules", None)
+        if block_modules is None:
+            block_modules = getattr(self, "_group_offload_block_modules", None)
+        if exclude_kwargs is None:
+            exclude_kwargs = getattr(self, "_skip_keys", None)
         apply_group_offloading(
             module=self,
             onload_device=onload_device,
@@ -584,7 +589,8 @@ def enable_group_offload(
             low_cpu_mem_usage=low_cpu_mem_usage,
             offload_to_disk_path=offload_to_disk_path,
             block_modules=block_modules,
-            pin_groups=pin_groups
+            exclude_kwargs=exclude_kwargs,
+            pin_groups=pin_groups,
         )
 
     def set_attention_backend(self, backend: str) -> None:
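The corresponding `ModelMixin.enable_group_offload` entry point would be called roughly as in the sketch below (not part of this commit; the checkpoint id is a placeholder, a CUDA device is assumed, and `block_modules`/`exclude_kwargs` fall back to the model's `_group_offload_block_modules`/`_skip_keys` attributes when left as `None`).

```python
import torch
from diffusers import UNet2DConditionModel

# "org/checkpoint" is a placeholder model id.
unet = UNet2DConditionModel.from_pretrained("org/checkpoint", subfolder="unet")
unet.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    pin_groups="first_last",
)
```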
