Merged
293 changes: 140 additions & 153 deletions src/diffusers/hooks/group_offloading.py

Large diffs are not rendered by default.
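The refactored group offloading hooks themselves are collapsed in this view. For orientation only, here is a minimal sketch of enabling group offloading on a model, mirroring the `enable_group_offload()` call used in the new tests at the bottom of this diff; the model class, checkpoint id, and devices are illustrative assumptions, not part of this PR:

```python
import torch
from diffusers import CogVideoXTransformer3DModel

# Illustrative checkpoint; any diffusers model exposing enable_group_offload()
# is configured the same way.
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Same arguments as the test helper below: groups of blocks stay on the offload
# device and are moved to the accelerator only while they run.
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",  # the tests also exercise "leaf_level"
    num_blocks_per_group=1,
    use_stream=True,
)
```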

47 changes: 30 additions & 17 deletions src/diffusers/loaders/lora_base.py
@@ -25,6 +25,7 @@
from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE

from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading
from ..models.modeling_utils import ModelMixin, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
@@ -391,7 +392,9 @@ def _load_lora_into_text_encoder(
adapter_name = get_adapter_name(text_encoder)

# <Unsafe code
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(
_pipeline
)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
@@ -410,6 +413,10 @@ def _load_lora_into_text_encoder(
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
elif is_group_offload:
for component in _pipeline.components.values():
if isinstance(component, torch.nn.Module):
_maybe_remove_and_reapply_group_offloading(component)
# Unsafe code />

if prefix is not None and not state_dict:
@@ -433,30 +440,36 @@ def _func_optionally_disable_offloading(_pipeline):

Returns:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
is_group_offload = False

if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
if not isinstance(component, nn.Module):
continue
is_group_offload = is_group_offload or _is_group_offload_enabled(component)
if not hasattr(component, "_hf_hook"):
continue
is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
is_sequential_cpu_offload = is_sequential_cpu_offload or (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)

logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
if is_sequential_cpu_offload or is_model_cpu_offload:
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
if is_sequential_cpu_offload or is_model_cpu_offload:
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
for _, component in _pipeline.components.items():
if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
continue
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

return (is_model_cpu_offload, is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)


class LoraBaseMixin:
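Taken together, the caller-side pattern these lora_base.py changes introduce looks roughly like the sketch below. It is a condensed illustration of the code above, not the exact merged implementation; the wrapper function name is hypothetical, while `_func_optionally_disable_offloading` and `_maybe_remove_and_reapply_group_offloading` are the helpers shown in this diff:

```python
import torch
from diffusers.hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from diffusers.loaders.lora_base import _func_optionally_disable_offloading


def _load_with_offloading_paused(_pipeline, load_fn):
    # Temporarily lift any offloading before touching the model's parameters.
    is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(_pipeline)

    load_fn()  # inject LoRA layers / load the state dict here

    # Restore whichever offloading scheme was active. Group offloading has no
    # pipeline-level toggle, so it is re-applied per nn.Module component.
    if is_model_cpu_offload:
        _pipeline.enable_model_cpu_offload()
    elif is_sequential_cpu_offload:
        _pipeline.enable_sequential_cpu_offload()
    elif is_group_offload:
        for component in _pipeline.components.values():
            if isinstance(component, torch.nn.Module):
                _maybe_remove_and_reapply_group_offloading(component)
```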
11 changes: 10 additions & 1 deletion src/diffusers/loaders/peft.py
@@ -22,6 +22,7 @@
import safetensors
import torch

from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..utils import (
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
@@ -256,7 +257,9 @@ def load_lora_adapter(

# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error.
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
_pipeline
)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -347,6 +350,10 @@ def map_state_dict_for_hotswap(sd):
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
elif is_group_offload:
for component in _pipeline.components.values():
if isinstance(component, torch.nn.Module):
_maybe_remove_and_reapply_group_offloading(component)
# Unsafe code />

if prefix is not None and not state_dict:
@@ -687,6 +694,8 @@ def unload_lora(self):
if hasattr(self, "peft_config"):
del self.peft_config

_maybe_remove_and_reapply_group_offloading(self)

def disable_lora(self):
"""
Disables the active LoRA layers of the underlying model.
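The `unload_lora()` change above means group offloading survives a LoRA unload. A rough usage sketch of the round trip the new tests exercise, assuming `pipe` is a loaded pipeline whose denoiser already has group offloading enabled and the weights path is a placeholder:

```python
import torch

# Loading LoRA weights into a group-offloaded denoiser: the group offloading
# hooks are removed and re-applied internally during the load.
pipe.load_lora_weights("path/to/pytorch_lora_weights.safetensors")
output_with_lora = pipe(prompt="a photo of a cat", generator=torch.manual_seed(0))[0]

# Unloading now also refreshes the group offloading hooks on the model, so
# inference keeps working unchanged afterwards.
pipe.unload_lora_weights()
output_without_lora = pipe(prompt="a photo of a cat", generator=torch.manual_seed(0))[0]
```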
19 changes: 15 additions & 4 deletions src/diffusers/loaders/unet.py
@@ -22,6 +22,7 @@
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args

from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..models.embeddings import (
ImageProjection,
IPAdapterFaceIDImageProjection,
@@ -203,6 +204,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
is_model_cpu_offload = False
is_sequential_cpu_offload = False
is_group_offload = False

if is_lora:
deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
@@ -211,7 +213,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
if is_custom_diffusion:
attn_processors = self._process_custom_diffusion(state_dict=state_dict)
elif is_lora:
is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora(
state_dict=state_dict,
unet_identifier_key=self.unet_name,
network_alphas=network_alphas,
@@ -230,7 +232,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

# For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
if is_custom_diffusion and _pipeline is not None:
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
_pipeline=_pipeline
)

# only custom diffusion needs to set attn processors
self.set_attn_processor(attn_processors)
@@ -241,6 +245,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
elif is_group_offload:
for component in _pipeline.components.values():
if isinstance(component, torch.nn.Module):
_maybe_remove_and_reapply_group_offloading(component)
# Unsafe code />

def _process_custom_diffusion(self, state_dict):
@@ -307,6 +315,7 @@ def _process_lora(

is_model_cpu_offload = False
is_sequential_cpu_offload = False
is_group_offload = False
state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict

if len(state_dict_to_be_used) > 0:
@@ -356,7 +365,9 @@ def _process_lora(

# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
_pipeline
)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -389,7 +400,7 @@ def _process_lora(
if warn_msg:
logger.warning(warn_msg)

return is_model_cpu_offload, is_sequential_cpu_offload
return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload

@classmethod
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
9 changes: 9 additions & 0 deletions tests/lora/test_lora_layers_cogvideox.py
@@ -16,6 +16,7 @@
import unittest

import torch
from parameterized import parameterized
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import (
Expand All @@ -28,6 +29,7 @@
from diffusers.utils.testing_utils import (
floats_tensor,
require_peft_backend,
require_torch_accelerator,
)


@@ -127,6 +129,13 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self):
def test_lora_scale_kwargs_match_fusion(self):
super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)

@parameterized.expand([("block_level", True), ("leaf_level", False)])
@require_torch_accelerator
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
# TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
Comment on lines +132 to +137

Member: We can do this, or we could detect whether the test class is CogView4 or CogVideoX and use pytest.skip(). Up to you.

Contributor (author): I prefer that the base test class not have any knowledge of the model test classes that derive from it. The current implementation works for any model that overrides the test, so it is also much cleaner.


@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
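For reference, the alternative the reviewer floats in the conversation above would look roughly like this in the shared base class. This is a hypothetical sketch, not what was merged; the class-name check is illustrative, and it assumes `import pytest` plus the same module-level imports as tests/lora/utils.py:

```python
@parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)])
@require_torch_accelerator
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
    # Hypothetical: skip the problematic (leaf_level, use_stream=True) case for the Cog* models
    # instead of having each model test class override the whole test.
    if "Cog" in type(self).__name__ and offload_type == "leaf_level" and use_stream:
        pytest.skip(
            "See https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338"
        )
    self._test_group_offloading_inference_denoiser(offload_type, use_stream)
```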
16 changes: 15 additions & 1 deletion tests/lora/test_lora_layers_cogview4.py
@@ -18,10 +18,17 @@

import numpy as np
import torch
from parameterized import parameterized
from transformers import AutoTokenizer, GlmModel

from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device
from diffusers.utils.testing_utils import (
floats_tensor,
require_peft_backend,
require_torch_accelerator,
skip_mps,
torch_device,
)


sys.path.append(".")
@@ -141,6 +148,13 @@ def test_simple_inference_save_pretrained(self):
"Loading from saved checkpoints should give same results.",
)

@parameterized.expand([("block_level", True), ("leaf_level", False)])
@require_torch_accelerator
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
# TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
super()._test_group_offloading_inference_denoiser(offload_type, use_stream)

@unittest.skip("Not supported in CogView4.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
71 changes: 71 additions & 0 deletions tests/lora/utils.py
@@ -39,6 +39,7 @@
is_torch_version,
require_peft_backend,
require_peft_version_greater,
require_torch_accelerator,
require_transformers_version_greater,
skip_mps,
torch_device,
@@ -2355,3 +2356,73 @@ def test_inference_load_delete_load_adapters(self):
pipe.load_lora_weights(tmpdirname)
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))

def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook

onload_device = torch_device
offload_device = torch.device("cpu")

components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))

components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet

pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
check_if_lora_correctly_set(denoiser)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

# Test group offloading with load_lora_weights
denoiser.enable_group_offload(
onload_device=onload_device,
offload_device=offload_device,
offload_type=offload_type,
num_blocks_per_group=1,
use_stream=use_stream,
)
group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
self.assertTrue(group_offload_hook_1 is not None)
output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

# Test group offloading after removing the lora
pipe.unload_lora_weights()
group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser)
self.assertTrue(group_offload_hook_2 is not None)
output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841

# Add the lora again and check if group offloading works
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
check_if_lora_correctly_set(denoiser)
group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser)
self.assertTrue(group_offload_hook_3 is not None)
output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]

self.assertTrue(np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3))

@parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)])
@require_torch_accelerator
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
for cls in inspect.getmro(self.__class__):
if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests:
# Skip this test if it is overwritten by child class. We need to do this because parameterized
# materializes the test methods on invocation which cannot be overridden.
return
self._test_group_offloading_inference_denoiser(offload_type, use_stream)