
Commit 1f7939d

Merge branch 'main' into main
2 parents 4845478 + f064b3b commit 1f7939d

File tree

15 files changed: +554 lines, -238 lines

src/diffusers/hooks/group_offloading.py

Lines changed: 140 additions & 153 deletions
Large diffs are not rendered by default.

src/diffusers/loaders/lora_base.py

Lines changed: 30 additions & 17 deletions

@@ -25,6 +25,7 @@
 from huggingface_hub import model_info
 from huggingface_hub.constants import HF_HUB_OFFLINE
 
+from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading
 from ..models.modeling_utils import ModelMixin, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
@@ -391,7 +392,9 @@ def _load_lora_into_text_encoder(
             adapter_name = get_adapter_name(text_encoder)
 
         # <Unsafe code
-        is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
+        is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(
+            _pipeline
+        )
         # inject LoRA layers and load the state dict
         # in transformers we automatically check whether the adapter name is already in use or not
         text_encoder.load_adapter(
@@ -410,6 +413,10 @@ def _load_lora_into_text_encoder(
             _pipeline.enable_model_cpu_offload()
         elif is_sequential_cpu_offload:
             _pipeline.enable_sequential_cpu_offload()
+        elif is_group_offload:
+            for component in _pipeline.components.values():
+                if isinstance(component, torch.nn.Module):
+                    _maybe_remove_and_reapply_group_offloading(component)
         # Unsafe code />
 
     if prefix is not None and not state_dict:
@@ -433,30 +440,36 @@ def _func_optionally_disable_offloading(_pipeline):
 
     Returns:
         tuple:
-            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
     """
     is_model_cpu_offload = False
     is_sequential_cpu_offload = False
+    is_group_offload = False
 
     if _pipeline is not None and _pipeline.hf_device_map is None:
         for _, component in _pipeline.components.items():
-            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                if not is_model_cpu_offload:
-                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
-                if not is_sequential_cpu_offload:
-                    is_sequential_cpu_offload = (
-                        isinstance(component._hf_hook, AlignDevicesHook)
-                        or hasattr(component._hf_hook, "hooks")
-                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                    )
+            if not isinstance(component, nn.Module):
+                continue
+            is_group_offload = is_group_offload or _is_group_offload_enabled(component)
+            if not hasattr(component, "_hf_hook"):
+                continue
+            is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
+            is_sequential_cpu_offload = is_sequential_cpu_offload or (
+                isinstance(component._hf_hook, AlignDevicesHook)
+                or hasattr(component._hf_hook, "hooks")
+                and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+            )
 
-                logger.info(
-                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                )
-                if is_sequential_cpu_offload or is_model_cpu_offload:
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+        if is_sequential_cpu_offload or is_model_cpu_offload:
+            logger.info(
+                "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+            )
+            for _, component in _pipeline.components.items():
+                if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
+                    continue
+                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
 
-    return (is_model_cpu_offload, is_sequential_cpu_offload)
+    return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
 
 
 class LoraBaseMixin:
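
The third returned flag changes the helper's arity, which is why every call site in this commit moves to three-way unpacking. A minimal, self-contained sketch of what would break otherwise (the stub below stands in for the real helper and is not diffusers code):

# Hypothetical stand-in for the updated _func_optionally_disable_offloading.
def optionally_disable_offloading_stub():
    return (False, False, True)  # (model_cpu_offload, sequential_cpu_offload, group_offload)

try:
    # Old two-way unpacking, as used before this commit, now fails.
    is_model_cpu_offload, is_sequential_cpu_offload = optionally_disable_offloading_stub()
except ValueError as err:
    print(f"stale call site: {err}")  # too many values to unpack (expected 2)

# Updated call sites unpack all three flags and branch on the new one.
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = optionally_disable_offloading_stub()
print(is_group_offload)  # True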

src/diffusers/loaders/peft.py

Lines changed: 22 additions & 4 deletions

@@ -22,6 +22,7 @@
 import safetensors
 import torch
 
+from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
 from ..utils import (
     MIN_PEFT_VERSION,
     USE_PEFT_BACKEND,
@@ -243,20 +244,29 @@ def load_lora_adapter(
                     k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
                 }
 
-            # create LoraConfig
-            lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)
-
             # adapter_name
             if adapter_name is None:
                 adapter_name = get_adapter_name(self)
 
+            # create LoraConfig
+            lora_config = _create_lora_config(
+                state_dict,
+                network_alphas,
+                metadata,
+                rank,
+                model_state_dict=self.state_dict(),
+                adapter_name=adapter_name,
+            )
+
             # <Unsafe code
             # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
             # Now we remove any existing hooks to `_pipeline`.
 
             # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
             # otherwise loading LoRA weights will lead to an error.
-            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+                _pipeline
+            )
             peft_kwargs = {}
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -347,6 +357,10 @@ def map_state_dict_for_hotswap(sd):
                 _pipeline.enable_model_cpu_offload()
             elif is_sequential_cpu_offload:
                 _pipeline.enable_sequential_cpu_offload()
+            elif is_group_offload:
+                for component in _pipeline.components.values():
+                    if isinstance(component, torch.nn.Module):
+                        _maybe_remove_and_reapply_group_offloading(component)
             # Unsafe code />
 
         if prefix is not None and not state_dict:
@@ -686,6 +700,10 @@ def unload_lora(self):
         recurse_remove_peft_layers(self)
         if hasattr(self, "peft_config"):
             del self.peft_config
+        if hasattr(self, "_hf_peft_config_loaded"):
+            self._hf_peft_config_loaded = None
+
+        _maybe_remove_and_reapply_group_offloading(self)
 
     def disable_lora(self):
         """

src/diffusers/loaders/unet.py

Lines changed: 15 additions & 4 deletions

@@ -22,6 +22,7 @@
 import torch.nn.functional as F
 from huggingface_hub.utils import validate_hf_hub_args
 
+from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
 from ..models.embeddings import (
     ImageProjection,
     IPAdapterFaceIDImageProjection,
@@ -203,6 +204,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
+        is_group_offload = False
 
         if is_lora:
             deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
@@ -211,7 +213,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         if is_custom_diffusion:
             attn_processors = self._process_custom_diffusion(state_dict=state_dict)
         elif is_lora:
-            is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora(
                 state_dict=state_dict,
                 unet_identifier_key=self.unet_name,
                 network_alphas=network_alphas,
@@ -230,7 +232,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
 
         # For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
         if is_custom_diffusion and _pipeline is not None:
-            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+                _pipeline=_pipeline
+            )
 
             # only custom diffusion needs to set attn processors
             self.set_attn_processor(attn_processors)
@@ -241,6 +245,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
             _pipeline.enable_model_cpu_offload()
         elif is_sequential_cpu_offload:
             _pipeline.enable_sequential_cpu_offload()
+        elif is_group_offload:
+            for component in _pipeline.components.values():
+                if isinstance(component, torch.nn.Module):
+                    _maybe_remove_and_reapply_group_offloading(component)
         # Unsafe code />
 
     def _process_custom_diffusion(self, state_dict):
@@ -307,6 +315,7 @@ def _process_lora(
 
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
+        is_group_offload = False
         state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
 
         if len(state_dict_to_be_used) > 0:
@@ -356,7 +365,9 @@ def _process_lora(
 
             # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
             # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+                _pipeline
+            )
             peft_kwargs = {}
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -389,7 +400,7 @@ def _process_lora(
             if warn_msg:
                 logger.warning(warn_msg)
 
-        return is_model_cpu_offload, is_sequential_cpu_offload
+        return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload
 
     @classmethod
     # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading

src/diffusers/loaders/unet_loader_utils.py

Lines changed: 5 additions & 2 deletions

@@ -14,6 +14,8 @@
 import copy
 from typing import TYPE_CHECKING, Dict, List, Union
 
+from torch import nn
+
 from ..utils import logging
 
 
@@ -52,7 +54,7 @@ def _maybe_expand_lora_scales(
             weight_for_adapter,
             blocks_with_transformer,
             transformer_per_block,
-            unet.state_dict(),
+            model=unet,
             default_scale=default_scale,
         )
         for weight_for_adapter in weight_scales
@@ -65,7 +67,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
     scales: Union[float, Dict],
    blocks_with_transformer: Dict[str, int],
    transformer_per_block: Dict[str, int],
-    state_dict: None,
+    model: nn.Module,
     default_scale: float = 1.0,
 ):
     """
@@ -154,6 +156,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
 
         del scales[updown]
 
+    state_dict = model.state_dict()
     for layer in scales.keys():
         if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
             raise ValueError(

src/diffusers/schedulers/scheduling_scm.py

Lines changed: 0 additions & 1 deletion

@@ -168,7 +168,6 @@ def set_timesteps(
         else:
             # max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
             self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
-            print(f"Set timesteps: {self.timesteps}")
 
         self._step_index = None
         self._begin_index = None

src/diffusers/utils/peft_utils.py

Lines changed: 49 additions & 8 deletions

@@ -150,7 +150,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
                     module.set_scale(adapter_name, 1.0)
 
 
-def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
+def get_peft_kwargs(
+    rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
+):
     rank_pattern = {}
     alpha_pattern = {}
     r = lora_alpha = list(rank_dict.values())[0]
@@ -180,7 +182,6 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
     else:
         lora_alpha = set(network_alpha_dict.values()).pop()
 
-    # layer names without the Diffusers specific
     target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
     use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
     # for now we know that the "bias" keys are only associated with `lora_B`.
@@ -195,6 +196,21 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
         "use_dora": use_dora,
         "lora_bias": lora_bias,
     }
+
+    # Example: try load FusionX LoRA into Wan VACE
+    exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
+    if exclude_modules:
+        if not is_peft_version(">=", "0.14.0"):
+            msg = """
+            It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
+            version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
+            peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
+            https://github.com/huggingface/diffusers/issues/new
+            """
+            logger.debug(msg)
+        else:
+            lora_config_kwargs.update({"exclude_modules": exclude_modules})
+
     return lora_config_kwargs
 
 
@@ -294,19 +310,20 @@ def check_peft_version(min_version: str) -> None:
 
 
 def _create_lora_config(
-    state_dict,
-    network_alphas,
-    metadata,
-    rank_pattern_dict,
-    is_unet: bool = True,
+    state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
 ):
     from peft import LoraConfig
 
     if metadata is not None:
         lora_config_kwargs = metadata
     else:
         lora_config_kwargs = get_peft_kwargs(
-            rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
+            rank_pattern_dict,
+            network_alpha_dict=network_alphas,
+            peft_state_dict=state_dict,
+            is_unet=is_unet,
+            model_state_dict=model_state_dict,
+            adapter_name=adapter_name,
         )
 
     _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
@@ -371,3 +388,27 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
 
     if warn_msg:
         logger.warning(warn_msg)
+
+
+def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
+    """
+    Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing
+    the `model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
+    doesn't exist in `peft_state_dict`.
+    """
+    if model_state_dict is None:
+        return
+    all_modules = set()
+    string_to_replace = f"{adapter_name}." if adapter_name else ""
+
+    for name in model_state_dict.keys():
+        if string_to_replace:
+            name = name.replace(string_to_replace, "")
+        if "." in name:
+            module_name = name.rsplit(".", 1)[0]
+            all_modules.add(module_name)
+
+    target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
+    exclude_modules = list(all_modules - target_modules_set)
+
+    return exclude_modules
return exclude_modules

tests/lora/test_lora_layers_cogvideox.py

Lines changed: 9 additions & 0 deletions

@@ -16,6 +16,7 @@
 import unittest
 
 import torch
+from parameterized import parameterized
 from transformers import AutoTokenizer, T5EncoderModel
 
 from diffusers import (
@@ -28,6 +29,7 @@
 from diffusers.utils.testing_utils import (
     floats_tensor,
     require_peft_backend,
+    require_torch_accelerator,
 )
 
 
@@ -127,6 +129,13 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self):
     def test_lora_scale_kwargs_match_fusion(self):
         super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)
 
+    @parameterized.expand([("block_level", True), ("leaf_level", False)])
+    @require_torch_accelerator
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
+        # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
+        super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
     @unittest.skip("Not supported in CogVideoX.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
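
For readers unfamiliar with the decorator: `parameterized.expand` generates one test method per tuple, so the class above gains a block_level variant and a leaf_level variant of the group-offloading test. A minimal, self-contained illustration, assuming only that the `parameterized` package is installed (it is unrelated to diffusers internals):

import unittest

from parameterized import parameterized


class ExpandDemo(unittest.TestCase):
    # Two test methods are generated, one per (offload_type, use_stream) tuple,
    # mirroring how the CogVideoX LoRA test above is parameterized.
    @parameterized.expand([("block_level", True), ("leaf_level", False)])
    def test_offload_combination(self, offload_type, use_stream):
        self.assertIn(offload_type, {"block_level", "leaf_level"})
        self.assertIsInstance(use_stream, bool)


if __name__ == "__main__":
    unittest.main()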
