
Commit 954bb7d

handle .to() when group offload applied
1 parent 8804d74 commit 954bb7d

File tree: 4 files changed, +80 −19 lines

src/diffusers/hooks/group_offloading.py

Lines changed: 7 additions & 0 deletions
@@ -662,3 +662,10 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn
                 f"offloading strategy from Accelerate. If you want to apply group offloading, please "
                 f"disable the existing offloading strategy first. Offending module: {name} ({type(submodule)})"
             )
+
+
+def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
+    for submodule in module.modules():
+        if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
+            return True
+    return False
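
For context, the new `_is_group_offload_enabled` helper walks every submodule and reports whether a group-offloading hook has been registered in its `_diffusers_hook` registry. A minimal sketch of how it behaves, using an illustrative toy module (not from this commit) and assuming a CUDA device:

# Illustrative sketch only, not part of the diff; the toy module and devices are assumptions.
import torch

from diffusers.hooks import apply_group_offloading
from diffusers.hooks.group_offloading import _is_group_offload_enabled


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A ModuleList gives block_level offloading some blocks to group.
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = TinyModel()
print(_is_group_offload_enabled(model))  # False: no hooks attached yet

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
)
print(_is_group_offload_enabled(model))  # True: group-offloading hooks are now registered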

src/diffusers/models/modeling_utils.py

Lines changed: 20 additions & 0 deletions
@@ -1245,8 +1245,21 @@ def cuda(self, *args, **kwargs):
     # Adapted from `transformers`.
     @wraps(torch.nn.Module.to)
     def to(self, *args, **kwargs):
+        from ..hooks.group_offloading import _is_group_offload_enabled
+
+        device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs
         dtype_present_in_args = "dtype" in kwargs

+        # Try converting arguments to torch.device in case they are passed as strings
+        for arg in args:
+            if not isinstance(arg, str):
+                continue
+            try:
+                torch.device(arg)
+                device_arg_or_kwarg_present = True
+            except RuntimeError:
+                pass
+
         if not dtype_present_in_args:
             for arg in args:
                 if isinstance(arg, torch.dtype):
@@ -1271,6 +1284,13 @@ def to(self, *args, **kwargs):
                     "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                     f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
                 )
+
+        if _is_group_offload_enabled(self) and device_arg_or_kwarg_present:
+            logger.warning(
+                f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported."
+            )
+            return self
+
         return super().to(*args, **kwargs)

     # Taken from `transformers`.
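
The net effect of this change: calling `.to()` with a device (positional or keyword, `torch.device` or string) on a group-offloaded model now logs a warning and returns the model unchanged, while dtype-only casts still fall through to `torch.nn.Module.to`. A rough usage sketch; the checkpoint ID is illustrative and a CUDA device is assumed:

import torch

from diffusers import UNet2DConditionModel

# Illustrative checkpoint; any ModelMixin subclass behaves the same once group offloading is enabled.
unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="unet")
unet.enable_group_offload(torch.device("cuda"), offload_type="block_level", num_blocks_per_group=3)

# Logs the "is group offloaded and moving it using `.to()` is not supported" warning
# and returns the module as-is; no device transfer happens.
unet = unet.to("cuda")

# Only device arguments trigger the early return, so a dtype-only cast is still
# forwarded to torch.nn.Module.to().
unet = unet.to(dtype=torch.float16)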

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 31 additions & 19 deletions
@@ -394,6 +394,7 @@ def to(self, *args, **kwargs):
             )

         device = device or device_arg
+        device_type = torch.device(device).type if device is not None else None
         pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())

         # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
@@ -424,7 +425,7 @@ def module_is_offloaded(module):
                 "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
             )

-        if device and torch.device(device).type == "cuda":
+        if device_type == "cuda":
             if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
                 raise ValueError(
                     "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
@@ -437,7 +438,7 @@ def module_is_offloaded(module):

         # Display a warning in this case (the operation succeeds but the benefits are lost)
         pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
-        if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
+        if pipeline_is_offloaded and device_type == "cuda":
             logger.warning(
                 f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
             )
@@ -449,6 +450,7 @@ def module_is_offloaded(module):
         is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
         for module in modules:
             _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
+            is_group_offloaded = self._maybe_raise_error_if_group_offload_active(module=module)

             if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
                 logger.warning(
@@ -460,11 +462,21 @@ def module_is_offloaded(module):
                     f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
                 )

+            # Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling
+            # components can be from outside diffusers too, but still have group offloading enabled.
+            if (
+                self._maybe_raise_error_if_group_offload_active(raise_error=False, module=module)
+                and device is not None
+            ):
+                logger.warning(
+                    f"The module '{module.__class__.__name__}' is group offloaded and moving it to {device} via `.to()` is not supported."
+                )
+
             # This can happen for `transformer` models. CPU placement was added in
             # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
             if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
                 module.to(device=device)
-            elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
+            elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded:
                 module.to(device, dtype)

             if (
@@ -1075,7 +1087,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
                 The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
                 default to "cuda".
         """
-        self._check_group_offloading_inactive_or_raise_error()
+        self._maybe_raise_error_if_group_offload_active(raise_error=True)

         is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
         if is_pipeline_device_mapped:
@@ -1188,7 +1200,7 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
                 The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
                 default to "cuda".
         """
-        self._check_group_offloading_inactive_or_raise_error()
+        self._maybe_raise_error_if_group_offload_active(raise_error=True)

         if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
             from accelerate import cpu_offload
@@ -1914,23 +1926,23 @@ def from_pipe(cls, pipeline, **kwargs):

         return new_pipeline

-    def _check_group_offloading_inactive_or_raise_error(self) -> None:
-        from ..hooks import HookRegistry
-        from ..hooks.group_offloading import _GROUP_OFFLOADING
+    def _maybe_raise_error_if_group_offload_active(
+        self, raise_error: bool = False, module: Optional[torch.nn.Module] = None
+    ) -> bool:
+        from ..hooks.group_offloading import _is_group_offload_enabled

-        for name, component in self.components.items():
-            if not isinstance(component, torch.nn.Module):
-                continue
-            for module in component.modules():
-                if not hasattr(module, "_diffusers_hook"):
-                    continue
-                registry: HookRegistry = module._diffusers_hook
-                if registry.get_hook(_GROUP_OFFLOADING) is not None:
+        components = self.components.values() if module is None else [module]
+        components = [component for component in components if isinstance(component, torch.nn.Module)]
+        for component in components:
+            if _is_group_offload_enabled(component):
+                if raise_error:
                     raise ValueError(
-                        f"You are trying to apply model/sequential CPU offloading to a pipeline that contains "
-                        f"components with group offloading enabled. This is not supported. Please disable group "
-                        f"offloading for the '{name}' component of the pipeline to use other offloading methods."
+                        "You are trying to apply model/sequential CPU offloading to a pipeline that contains components "
+                        "with group offloading enabled. This is not supported. Please disable group offloading for "
+                        "components of the pipeline to use other offloading methods."
                     )
+                return True
+        return False


 class StableDiffusionMixin:
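
At the pipeline level, the same situation is handled per component: `pipeline.to()` warns about and skips any group-offloaded component instead of moving it, and `enable_model_cpu_offload()` / `enable_sequential_cpu_offload()` now raise through `_maybe_raise_error_if_group_offload_active(raise_error=True)`. A hedged sketch of the resulting behavior (checkpoint ID and devices are illustrative, CUDA assumed):

import torch

from diffusers import DiffusionPipeline

# Illustrative checkpoint; the behavior applies to any pipeline containing a group-offloaded component.
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
pipe.unet.enable_group_offload(torch.device("cuda"), offload_type="block_level", num_blocks_per_group=3)

# The UNet is skipped with a "is group offloaded and moving it ... via `.to()` is not supported"
# warning; the remaining components are moved as usual.
pipe.to("cuda")

# Model/sequential CPU offloading is mutually exclusive with group offloading, so this raises.
try:
    pipe.enable_model_cpu_offload()
except ValueError as err:
    print(err)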

tests/hooks/test_group_offloading.py

Lines changed: 22 additions & 0 deletions
@@ -19,6 +19,7 @@

 from diffusers.models import ModelMixin
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.utils import get_logger
 from diffusers.utils.testing_utils import require_torch_gpu, torch_device


@@ -153,6 +154,27 @@ def run_forward(model):
         # Memory assertions - offloading should reduce memory usage
         self.assertTrue(mem4 <= mem5 < mem2 < mem3 < mem1 < mem_baseline)

+    def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self):
+        if torch.device(torch_device).type != "cuda":
+            return
+        self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
+        logger = get_logger("diffusers.models.modeling_utils")
+        logger.setLevel("INFO")
+        with self.assertLogs(logger, level="WARNING") as cm:
+            self.model.to(torch_device)
+        self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
+
+    def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self):
+        if torch.device(torch_device).type != "cuda":
+            return
+        pipe = DummyPipeline(self.model)
+        self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
+        logger = get_logger("diffusers.pipelines.pipeline_utils")
+        logger.setLevel("INFO")
+        with self.assertLogs(logger, level="WARNING") as cm:
+            pipe.to(torch_device)
+        self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
+
     def test_error_raised_if_streams_used_and_no_cuda_device(self):
         original_is_available = torch.cuda.is_available
         torch.cuda.is_available = lambda: False