
Commit da88c33

refactor some repeated code
1 parent ba6c4a8 commit da88c33

File tree

3 files changed, +23 -16 lines changed


src/diffusers/hooks/group_offloading.py

Lines changed: 7 additions & 0 deletions
@@ -669,3 +669,10 @@ def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
         if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
             return True
     return False
+
+
+def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
+    for submodule in module.modules():
+        if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
+            return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
+    raise ValueError("Group offloading is not enabled for the provided module.")
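
The new helper centralizes the hook-registry walk that the next two files previously duplicated. Below is a minimal usage sketch of the intended call pattern; the toy module, the CUDA device, and the `apply_group_offloading` arguments are illustrative assumptions, not part of this commit.

# Sketch only: assumes a CUDA device is available and that group offloading is
# applied via diffusers' apply_group_offloading (arguments shown are illustrative).
import torch
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.group_offloading import _get_group_onload_device

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
)

try:
    # Reports the device the group offloading hooks onload to (here, cuda)
    device = _get_group_onload_device(model)
except ValueError:
    # Raised when no submodule carries a group offloading hook; fall back to the parameters
    device = next(model.parameters()).device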

src/diffusers/models/modeling_utils.py

Lines changed: 9 additions & 8 deletions
@@ -86,16 +86,17 @@
 
 
 def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
+    from ..hooks.group_offloading import _get_group_onload_device
+
     try:
-        if hasattr(parameter, "_diffusers_hook"):
-            for submodule in parameter.modules():
-                if not hasattr(submodule, "_diffusers_hook"):
-                    continue
-                registry = parameter._diffusers_hook
-                hook = registry.get_hook("group_offloading")
-                if hook is not None:
-                    return hook.group.onload_device
+        # Try to get the onload device from the group offloading hook
+        return _get_group_onload_device(parameter)
+    except ValueError:
+        pass
 
+    try:
+        # If the onload device is not available due to no group offloading hooks, try to get the device
+        # from the first parameter or buffer
         parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
         return next(parameters_and_buffers).device
     except StopIteration:

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 7 additions & 8 deletions
@@ -462,7 +462,7 @@ def module_is_offloaded(module):
                     f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
                 )
 
-            # Note: we also handle this as the ModelMixin level. The reason for doing it here too is that modeling
+            # Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling
             # components can be from outside diffusers too, but still have group offloading enabled.
             if (
                 self._maybe_raise_error_if_group_offload_active(raise_error=False, module=module)
@@ -1035,19 +1035,18 @@ def _execution_device(self):
         [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
         Accelerate's module hooks.
         """
+        from ..hooks.group_offloading import _get_group_onload_device
+
         # When apply group offloading at the leaf_level, we're in the same situation as accelerate's sequential
         # offloading. We need to return the onload device of the group offloading hooks so that the intermediates
         # required for computation (latents, prompt embeddings, etc.) can be created on the correct device.
         for name, model in self.components.items():
             if not isinstance(model, torch.nn.Module):
                 continue
-            for submodule in model.modules():
-                if not hasattr(submodule, "_diffusers_hook"):
-                    continue
-                registry = submodule._diffusers_hook
-                hook = registry.get_hook("group_offloading")
-                if hook is not None:
-                    return hook.group.onload_device
+            try:
+                return _get_group_onload_device(model)
+            except ValueError:
+                pass
 
         for name, model in self.components.items():
             if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload:
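
Taken together, the three files replace two hand-rolled scans over `_diffusers_hook` registries with the single `_get_group_onload_device` helper. Both call sites wrap it in `try`/`except ValueError` rather than pre-checking, so the existing fallbacks (the first parameter or buffer in `get_parameter_device`, Accelerate's hooks in `_execution_device`) are reached unchanged when no group offloading hook is present. The function-local imports of the helper are presumably there to avoid a circular import between the hooks module and the model/pipeline modules; the commit itself does not say.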
