
Commit 583a7e9

attempt to simplify & correct `to()`
1 parent 42d3a6a commit 583a7e9

File tree: 1 file changed (+26, -32 lines)


src/diffusers/pipelines/pipeline_utils.py

Lines changed: 26 additions & 32 deletions
@@ -391,49 +391,34 @@ def to(self, *args, **kwargs):
         pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
 
         # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
-        def module_is_sequentially_offloaded(module):
+        pipeline_is_sequentially_offloaded = hasattr(self, "_all_sequential_hooks") and self._all_sequential_hooks is not None and len(self._all_sequential_hooks) > 0
+        pipeline_is_offloaded = hasattr(self, "_all_hooks") and self._all_hooks is not None and len(self._all_hooks) > 0
+        pipeline_is_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1
+        def module_has_hooks(module):
             if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
                 return False
-
-            return hasattr(module, "_hf_hook") and (
-                isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
-                or hasattr(module._hf_hook, "hooks")
-                and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
-            )
-
-        def module_is_offloaded(module):
-            if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
-                return False
-
-            return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
-
-        # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
-        pipeline_is_sequentially_offloaded = any(
-            module_is_sequentially_offloaded(module) for _, module in self.components.items()
-        )
+            return hasattr(module, "_hf_hook") and module._hf_hook is not None
+        pipeline_has_hooks = any(module_has_hooks(module) for _, module in self.components.items())
 
         if device and torch.device(device).type == "cuda":
-            if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
+            if pipeline_is_offloaded:
                 raise ValueError(
-                    "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
+                    f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
+                )
+            if pipeline_is_sequentially_offloaded:
+                raise ValueError(
+                    "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now manually moving the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
                 )
             # PR: https://github.com/huggingface/accelerate/pull/3223/
-            elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
+            if pipeline_has_hooks and pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
                 raise ValueError(
                     "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
                 )
 
-        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
-        if is_pipeline_device_mapped:
+        if pipeline_is_device_mapped:
             raise ValueError(
                 "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
             )
 
-        # Display a warning in this case (the operation succeeds but the benefits are lost)
-        pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
-        if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
-            logger.warning(
-                f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
-            )
 
         module_names, _ = self._get_signature_keys(self)
         modules = [getattr(self, n, None) for n in module_names]
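
The net effect of this hunk: `to()` no longer probes each module's `accelerate` hook type but consults pipeline-level state (`_all_hooks`, `_all_sequential_hooks`, `hf_device_map`), and the model-offload case is promoted from a warning to a hard error. A minimal sketch of the new user-facing behavior, assuming this commit is applied (the checkpoint name is illustrative only):

import torch
from diffusers import DiffusionPipeline

# Any diffusers pipeline works here; the checkpoint is just an example.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # populates pipe._all_hooks

try:
    pipe.to("cuda")  # previously a logger.warning; now raises
except ValueError as err:
    print(err)  # "It seems like you have activated model offloading ..."
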
@@ -452,13 +437,18 @@ def module_is_offloaded(module):
                 logger.warning(
                     f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
                 )
+
+            is_device_mapped = module_has_hooks(module) and hasattr(module, "hf_device_map") and module.hf_device_map is not None
 
             # This can happen for `transformer` models. CPU placement was added in
             # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
             if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
                 module.to(device=device)
             elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
-                module.to(device, dtype)
+                if is_device_mapped:
+                    logger.warning(f"{module.__class__.__name__} has a device map {module.hf_device_map} and will not be moved to {device}.")
+                else:
+                    module.to(device, dtype)
 
             if (
                 module.dtype == torch.float16
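
The new `is_device_mapped` guard leaves any component that carries its own `hf_device_map` (and still has an `accelerate` hook attached) in place, warning instead of moving it. A standalone sketch of just that condition, using a stand-in module and the simplified hook check from the hunk above (the `accelerate` version guard is omitted here):

import torch

class MappedComponent(torch.nn.Module):
    """Stand-in for a component loaded with a `device_map`."""
    def __init__(self):
        super().__init__()
        self._hf_hook = object()             # accelerate attaches a hook when dispatching
        self.hf_device_map = {"": "cuda:0"}  # placement recorded at load time

def module_has_hooks(module):
    return hasattr(module, "_hf_hook") and module._hf_hook is not None

module = MappedComponent()
is_device_mapped = (
    module_has_hooks(module)
    and hasattr(module, "hf_device_map")
    and module.hf_device_map is not None
)
print(is_device_mapped)  # True -> this module would be warned about and left in place
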
@@ -1014,7 +1004,10 @@ def remove_all_hooks(self):
         for _, model in self.components.items():
             if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
                 accelerate.hooks.remove_hook_from_module(model, recurse=True)
-        self._all_hooks = []
+        if hasattr(self, "_all_hooks"):
+            self._all_hooks = []
+        if hasattr(self, "_all_sequential_hooks"):
+            self._all_sequential_hooks = []
 
     def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
@@ -1166,17 +1159,18 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
         if hasattr(device_mod, "empty_cache") and device_mod.is_available():
             device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
+        self._all_sequential_hooks = []
         for name, model in self.components.items():
             if not isinstance(model, torch.nn.Module):
                 continue
 
             if name in self._exclude_from_cpu_offload:
                 model.to(device)
             else:
-                # make sure to offload buffers if not all high level weights
                 # are of type nn.Module
                 offload_buffers = len(model._parameters) > 0
                 cpu_offload(model, device, offload_buffers=offload_buffers)
+                self._all_sequential_hooks.append(model._hf_hook)
 
     def reset_device_map(self):
         r"""
