Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/diffusers/hooks/group_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
non_blocking: bool = False,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage=False,
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
) -> None:
self.modules = modules
Expand Down Expand Up @@ -498,6 +498,8 @@ def _apply_group_offloading_block_level(
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
"""
if stream is not None and num_blocks_per_group != 1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is potentially breaking, no? What if there is existing code with `num_blocks_per_group > 1` and `stream=True`? If so, it might be better to raise a warning and set `num_blocks_per_group` to 1 when `stream` is True.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has been addressed in #11425

raise ValueError(f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}.")

# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
Expand All @@ -521,20 +523,16 @@ def _apply_group_offloading_block_level(
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=stream is None,
onload_self=True,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
modules_with_group_offloading.add(f"{name}.{j}")

# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
next_group = (
matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None
)

for group_module in group.modules:
_apply_group_offloading_hook(group_module, group, next_group)
_apply_group_offloading_hook(group_module, group, None)

# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
Expand All @@ -560,8 +558,10 @@ def _apply_group_offloading_block_level(
record_stream=False,
onload_self=True,
)
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
_apply_group_offloading_hook(module, unmatched_group, next_group)
if stream is None:
_apply_group_offloading_hook(module, unmatched_group, None)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, None)


def _apply_group_offloading_leaf_level(
Expand Down