18 changes: 9 additions & 9 deletions src/diffusers/hooks/group_offloading.py
@@ -57,7 +57,7 @@ def __init__(
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
         record_stream: Optional[bool] = False,
-        low_cpu_mem_usage=False,
+        low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
     ) -> None:
         self.modules = modules
@@ -498,6 +498,8 @@ def _apply_group_offloading_block_level(
             option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
             the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
+    if stream is not None and num_blocks_per_group != 1:
+        raise ValueError(f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}.")

     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()

Review thread on the check above:

> **Collaborator:** This is potentially breaking, no? What if there is existing code with num_blocks_per_group>1 and stream=True? If so, it might be better to raise a warning and set the num_blocks_per_group to 1 if stream is True?
>
> **Contributor (author):** Has been addressed in #11425
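For illustration only, the warn-and-clamp alternative the collaborator describes could look roughly like the sketch below. The helper name `_clamp_blocks_for_stream` is hypothetical; this is not what the PR merged (the merged code raises, and the follow-up is tracked in #11425).

```python
# Hypothetical sketch of the reviewer's suggestion: warn and fall back to
# num_blocks_per_group=1 instead of raising, so existing callers that combine
# use_stream=True with a larger group size keep working.
import logging
from typing import Optional

import torch

logger = logging.getLogger(__name__)


def _clamp_blocks_for_stream(stream: Optional[torch.cuda.Stream], num_blocks_per_group: int) -> int:
    # Illustrative helper, not part of diffusers.
    if stream is not None and num_blocks_per_group != 1:
        logger.warning(
            f"Streams are only supported with num_blocks_per_group=1, but got {num_blocks_per_group}. "
            "Falling back to num_blocks_per_group=1."
        )
        return 1
    return num_blocks_per_group
```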
@@ -521,20 +523,16 @@ def _apply_group_offloading_block_level(
                 stream=stream,
                 record_stream=record_stream,
                 low_cpu_mem_usage=low_cpu_mem_usage,
-                onload_self=stream is None,
+                onload_self=True,
             )
             matched_module_groups.append(group)
             for j in range(i, i + len(current_modules)):
                 modules_with_group_offloading.add(f"{name}.{j}")

     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
-        next_group = (
-            matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None
-        )
-
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, next_group)
+            _apply_group_offloading_hook(group_module, group, None)

     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -560,8 +558,10 @@ def _apply_group_offloading_block_level(
         record_stream=False,
         onload_self=True,
     )
-    next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
-    _apply_group_offloading_hook(module, unmatched_group, next_group)
+    if stream is None:
+        _apply_group_offloading_hook(module, unmatched_group, None)
+    else:
+        _apply_lazy_group_offloading_hook(module, unmatched_group, None)


 def _apply_group_offloading_leaf_level(
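The streaming branch above now attaches a lazy hook to the top-level module instead of chaining groups in registration order. The standalone sketch below (illustrative only, not diffusers internals) shows why registration order and invocation order can diverge, which is exactly the situation the new test in this PR exercises: prefetching derived purely from registration order would onload the wrong block next.

```python
# Standalone illustration: blocks registered in one order can be invoked in another,
# so a prefetch schedule built from registration order may point at the wrong "next"
# block. Recording the order during the first forward pass avoids this.
import torch


class TinyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # single_blocks are registered first but invoked second, mirroring
        # DummyModelWithMultipleBlocks in the test added below.
        self.single_blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(2)])
        self.double_blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(2)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.double_blocks:
            x = block(x)
        for block in self.single_blocks:
            x = block(x)
        return x


model = TinyModel()
invocation_order = []
for name, submodule in model.named_modules():
    if isinstance(submodule, torch.nn.Linear):
        # Record the qualified name of each block the moment it is actually called.
        submodule.register_forward_pre_hook(lambda mod, args, name=name: invocation_order.append(name))

model(torch.randn(1, 4))
# Prints double_blocks.* before single_blocks.*, the reverse of registration order.
print(invocation_order)
```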
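For context, here is a minimal usage sketch of the configuration these hunks enforce. The checkpoint is an illustrative choice, and the availability of `low_cpu_mem_usage` on `enable_group_offload` is assumed for recent diffusers versions.

```python
# Minimal usage sketch (assumes a CUDA device and a recent diffusers version).
# The checkpoint is illustrative; any ModelMixin whose top-level children include
# nn.ModuleList / nn.Sequential blocks is grouped the same way at block level.
import torch
from diffusers import CogVideoXTransformer3DModel

model = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
model.enable_group_offload(
    torch.device("cuda"),      # onload device
    offload_type="block_level",
    num_blocks_per_group=1,    # required whenever use_stream=True after this change
    use_stream=True,
    low_cpu_mem_usage=True,    # optional; see the docstring note in the first hunk
)
# Passing num_blocks_per_group=2 together with use_stream=True now raises the ValueError added above.
```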
55 changes: 55 additions & 0 deletions tests/hooks/test_group_offloading.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
 import gc
 import unittest

@@ -20,6 +21,7 @@
 from diffusers.models import ModelMixin
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.utils import get_logger
+from diffusers.utils.import_utils import compare_versions
 from diffusers.utils.testing_utils import require_torch_gpu, torch_device


@@ -58,6 +60,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


+# This model implementation contains one type of block (single_blocks) instantiated before another type of block (double_blocks).
+# The invocation order of these blocks, however, is first the double_blocks and then the single_blocks.
+# With group offloading implementation before https://github.com/huggingface/diffusers/pull/11375, such a modeling implementation
+# would result in a device mismatch error because of the assumptions made by the code. The failure case occurs when using:
+#     offload_type="block_level", num_blocks_per_group=2, use_stream=True
+# Post the linked PR, the implementation will work as expected.
+class DummyModelWithMultipleBlocks(ModelMixin):
+    def __init__(
+        self, in_features: int, hidden_features: int, out_features: int, num_layers: int, num_single_layers: int
+    ) -> None:
+        super().__init__()
+
+        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
+        self.activation = torch.nn.ReLU()
+        self.single_blocks = torch.nn.ModuleList(
+            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_single_layers)]
+        )
+        self.double_blocks = torch.nn.ModuleList(
+            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
+        )
+        self.linear_2 = torch.nn.Linear(hidden_features, out_features)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.linear_1(x)
+        x = self.activation(x)
+        for block in self.double_blocks:
+            x = block(x)
+        for block in self.single_blocks:
+            x = block(x)
+        x = self.linear_2(x)
+        return x
+
+
 class DummyPipeline(DiffusionPipeline):
     model_cpu_offload_seq = "model"

@@ -212,3 +247,23 @@ def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module
         pipe.enable_sequential_cpu_offload()
         with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
             pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
+
+    def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
+        if torch.device(torch_device).type != "cuda":
+            return
+        model = DummyModelWithMultipleBlocks(
+            in_features=self.in_features,
+            hidden_features=self.hidden_features,
+            out_features=self.out_features,
+            num_layers=self.num_layers,
+            num_single_layers=self.num_layers + 1,
+        )
+        model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
+
+        context = contextlib.nullcontext()
+        if compare_versions("diffusers", "<=", "0.33.0"):
+            # Will raise a device mismatch RuntimeError mentioning weights are on CPU but input is on device
+            context = self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device")
+
+        with context:
+            model(self.input)