diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index ac6cf653641b..fd2e16c094f8 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -57,7 +57,7 @@ def __init__(
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
         record_stream: Optional[bool] = False,
-        low_cpu_mem_usage=False,
+        low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
     ) -> None:
         self.modules = modules
@@ -498,6 +498,8 @@ def _apply_group_offloading_block_level(
             option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
             the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
+    if stream is not None and num_blocks_per_group != 1:
+        raise ValueError(f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}.")
 
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
@@ -521,7 +523,7 @@ def _apply_group_offloading_block_level(
             stream=stream,
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
-            onload_self=stream is None,
+            onload_self=True,
         )
         matched_module_groups.append(group)
         for j in range(i, i + len(current_modules)):
@@ -529,12 +531,8 @@ def _apply_group_offloading_block_level(
 
     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
-        next_group = (
-            matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None
-        )
-
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, next_group)
+            _apply_group_offloading_hook(group_module, group, None)
 
     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -560,8 +558,10 @@ def _apply_group_offloading_block_level(
         record_stream=False,
         onload_self=True,
     )
-    next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
-    _apply_group_offloading_hook(module, unmatched_group, next_group)
+    if stream is None:
+        _apply_group_offloading_hook(module, unmatched_group, None)
+    else:
+        _apply_lazy_group_offloading_hook(module, unmatched_group, None)
 
 
 def _apply_group_offloading_leaf_level(
diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py
index d8f41fc2b1ae..37c5c9451b7e 100644
--- a/tests/hooks/test_group_offloading.py
+++ b/tests/hooks/test_group_offloading.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import gc
 import unittest
 
@@ -20,6 +21,7 @@
 from diffusers.models import ModelMixin
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.utils import get_logger
+from diffusers.utils.import_utils import compare_versions
 from diffusers.utils.testing_utils import require_torch_gpu, torch_device
 
 
@@ -58,6 +60,39 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+# This model implementation contains one type of block (single_blocks) instantiated before another type of block (double_blocks).
+# The invocation order of these blocks, however, is first the double_blocks and then the single_blocks.
+# With group offloading implementation before https://github.com/huggingface/diffusers/pull/11375, such a modeling implementation
+# would result in a device mismatch error because of the assumptions made by the code. The failure case occurs when using:
+# offload_type="block_level", num_blocks_per_group=2, use_stream=True
+# Post the linked PR, the implementation will work as expected.
+class DummyModelWithMultipleBlocks(ModelMixin):
+    def __init__(
+        self, in_features: int, hidden_features: int, out_features: int, num_layers: int, num_single_layers: int
+    ) -> None:
+        super().__init__()
+
+        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
+        self.activation = torch.nn.ReLU()
+        self.single_blocks = torch.nn.ModuleList(
+            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_single_layers)]
+        )
+        self.double_blocks = torch.nn.ModuleList(
+            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
+        )
+        self.linear_2 = torch.nn.Linear(hidden_features, out_features)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.linear_1(x)
+        x = self.activation(x)
+        for block in self.double_blocks:
+            x = block(x)
+        for block in self.single_blocks:
+            x = block(x)
+        x = self.linear_2(x)
+        return x
+
+
 class DummyPipeline(DiffusionPipeline):
     model_cpu_offload_seq = "model"
 
@@ -212,3 +247,23 @@ def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module
         pipe.enable_sequential_cpu_offload()
         with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
             pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
+
+    def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
+        if torch.device(torch_device).type != "cuda":
+            return
+        model = DummyModelWithMultipleBlocks(
+            in_features=self.in_features,
+            hidden_features=self.hidden_features,
+            out_features=self.out_features,
+            num_layers=self.num_layers,
+            num_single_layers=self.num_layers + 1,
+        )
+        model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
+
+        context = contextlib.nullcontext()
+        if compare_versions("diffusers", "<=", "0.33.0"):
+            # Will raise a device mismatch RuntimeError mentioning weights are on CPU but input is on device
+            context = self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device")
+
+        with context:
+            model(self.input)
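For reference, the snippet below is a minimal usage sketch of the code path this patch exercises; it is not part of the diff. The toy module, its name, and its sizes are illustrative assumptions. Only enable_group_offload and the argument combination offload_type="block_level", num_blocks_per_group=1, use_stream=True come from the test above; a CUDA device and a diffusers build containing this change are assumed, since streamed offloading requires both.

import torch

from diffusers.models import ModelMixin


class ToyBlockModel(ModelMixin):
    # Hypothetical toy model: blocks are defined in one order (single_blocks first) but
    # invoked in the other, mirroring DummyModelWithMultipleBlocks from the test above.
    def __init__(self, dim: int = 64, num_blocks: int = 2) -> None:
        super().__init__()
        self.single_blocks = torch.nn.ModuleList([torch.nn.Linear(dim, dim) for _ in range(num_blocks)])
        self.double_blocks = torch.nn.ModuleList([torch.nn.Linear(dim, dim) for _ in range(num_blocks)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.double_blocks:  # invoked first although defined second
            x = block(x)
        for block in self.single_blocks:
            x = block(x)
        return x


model = ToyBlockModel()
# Streamed block-level offloading; the check added in this patch requires num_blocks_per_group=1 when use_stream=True.
model.enable_group_offload(torch.device("cuda"), offload_type="block_level", num_blocks_per_group=1, use_stream=True)
output = model(torch.randn(2, 64, device="cuda"))  # runs without a device-mismatch error once this patch is applied

On releases without the fix (the new test gates on diffusers <= 0.33.0), the same call pattern fails with the "Expected all tensors to be on the same device" RuntimeError that the test asserts, and per the check added to _apply_group_offloading_block_level, requesting use_stream=True with num_blocks_per_group != 1 now raises a ValueError up front.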