Commit 8804d74

raise errors if multiple offloading strategies used; add relevant tests
1 parent 840576a · commit 8804d74

3 files changed: 78 additions & 0 deletions

src/diffusers/hooks/group_offloading.py

Lines changed: 17 additions & 0 deletions
@@ -22,6 +22,7 @@
 
 
 if is_accelerate_available():
+    from accelerate.hooks import AlignDevicesHook, CpuOffload
     from accelerate.utils import send_to_device
 
 
@@ -341,6 +342,8 @@ def apply_group_offloading(
     else:
         raise ValueError("Using streams for data transfer requires a CUDA device.")
 
+    _raise_error_if_accelerate_model_or_sequential_hook_present(module)
+
     if offload_type == "block_level":
         if num_blocks_per_group is None:
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
@@ -645,3 +648,17 @@ def _find_parent_module_in_module_dict(name: str, module_dict: Dict[str, torch.nn.Module]) -> str:
             return parent_name
         atoms.pop()
     return ""
+
+
+def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn.Module) -> None:
+    if not is_accelerate_available():
+        return
+    for name, submodule in module.named_modules():
+        if not hasattr(submodule, "_hf_hook"):
+            continue
+        if isinstance(submodule._hf_hook, (AlignDevicesHook, CpuOffload)):
+            raise ValueError(
+                f"Cannot apply group offloading to a module that is already applying an alternative "
+                f"offloading strategy from Accelerate. If you want to apply group offloading, please "
+                f"disable the existing offloading strategy first. Offending module: {name} ({type(submodule)})"
+            )
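In practice, this new guard makes the conflict explicit on the model side. A minimal sketch of the failure mode it catches, assuming a CUDA device; `AutoencoderKL` is used purely as an illustrative model (any `ModelMixin` subclass that supports group offloading behaves the same):

import torch
from accelerate import cpu_offload

from diffusers import AutoencoderKL

# Attach one of Accelerate's offloading strategies first; cpu_offload
# installs its hooks (CpuOffload/AlignDevicesHook) on the module tree.
model = AutoencoderKL()
cpu_offload(model, execution_device=torch.device("cuda"))

# apply_group_offloading() now walks named_modules(), finds the Accelerate
# hook on `_hf_hook`, and fails fast instead of silently stacking strategies:
model.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_type="block_level",
    num_blocks_per_group=1,
)  # raises ValueError("Cannot apply group offloading ...")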

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 22 additions & 0 deletions
@@ -1075,6 +1075,8 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
                 The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
                 default to "cuda".
         """
+        self._check_group_offloading_inactive_or_raise_error()
+
         is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
         if is_pipeline_device_mapped:
             raise ValueError(
@@ -1186,6 +1188,8 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
                 The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
                 default to "cuda".
         """
+        self._check_group_offloading_inactive_or_raise_error()
+
         if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
             from accelerate import cpu_offload
         else:
@@ -1910,6 +1914,24 @@ def from_pipe(cls, pipeline, **kwargs):
 
         return new_pipeline
 
+    def _check_group_offloading_inactive_or_raise_error(self) -> None:
+        from ..hooks import HookRegistry
+        from ..hooks.group_offloading import _GROUP_OFFLOADING
+
+        for name, component in self.components.items():
+            if not isinstance(component, torch.nn.Module):
+                continue
+            for module in component.modules():
+                if not hasattr(module, "_diffusers_hook"):
+                    continue
+                registry: HookRegistry = module._diffusers_hook
+                if registry.get_hook(_GROUP_OFFLOADING) is not None:
+                    raise ValueError(
+                        f"You are trying to apply model/sequential CPU offloading to a pipeline that contains "
+                        f"components with group offloading enabled. This is not supported. Please disable group "
+                        f"offloading for the '{name}' component of the pipeline to use other offloading methods."
+                    )
+
 
 class StableDiffusionMixin:
     r"""

tests/hooks/test_group_offloading.py

Lines changed: 39 additions & 0 deletions
@@ -18,6 +18,7 @@
 import torch
 
 from diffusers.models import ModelMixin
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.utils.testing_utils import require_torch_gpu, torch_device
 
 
@@ -56,6 +57,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+class DummyPipeline(DiffusionPipeline):
+    model_cpu_offload_seq = "model"
+
+    def __init__(self, model: torch.nn.Module) -> None:
+        super().__init__()
+
+        self.register_modules(model=model)
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        for _ in range(2):
+            x = x + 0.1 * self.model(x)
+        return x
+
+
 @require_torch_gpu
 class GroupOffloadTests(unittest.TestCase):
     in_features = 64
@@ -151,3 +166,27 @@ def test_error_raised_if_supports_group_offloading_false(self):
         self.model._supports_group_offloading = False
         with self.assertRaisesRegex(ValueError, "does not support group offloading"):
             self.model.enable_group_offload(onload_device=torch.device("cuda"))
+
+    def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self):
+        pipe = DummyPipeline(self.model)
+        pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
+        with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"):
+            pipe.enable_model_cpu_offload()
+
+    def test_error_raised_if_sequential_offloading_applied_on_group_offloaded_module(self):
+        pipe = DummyPipeline(self.model)
+        pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
+        with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"):
+            pipe.enable_sequential_cpu_offload()
+
+    def test_error_raised_if_group_offloading_applied_on_model_offloaded_module(self):
+        pipe = DummyPipeline(self.model)
+        pipe.enable_model_cpu_offload()
+        with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
+            pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
+
+    def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module(self):
+        pipe = DummyPipeline(self.model)
+        pipe.enable_sequential_cpu_offload()
+        with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
+            pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
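The four new tests cover both directions of the conflict (group offloading applied first vs. second). On a CUDA machine, something like `pytest tests/hooks/test_group_offloading.py -k offloaded_module -v` should select just these tests; the whole `GroupOffloadTests` class is gated by `@require_torch_gpu`, so they are skipped on CPU-only runners.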
