Commit a2f0db5

[LoRA] don't break offloading for incompatible lora ckpts. (#5085)
* don't break offloading for incompatible lora ckpts.
* debugging
* better condition.
* fix
* fix
* fix
* fix

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 92f6693 commit a2f0db5
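
As a minimal sketch of the scenario this commit protects against (the model ID, LoRA path, and prompt below are placeholders, not taken from the commit; the behavior follows the diff):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()  # installs accelerate offload hooks on the pipeline components

try:
    # With this commit, the checkpoint is validated *before* any hooks are removed,
    # so a bad file no longer leaves the pipeline half-unhooked.
    pipe.load_lora_weights("path/to/maybe_incompatible_lora.safetensors")
except ValueError:
    # "Invalid LoRA checkpoint." -- offloading is still intact at this point.
    pass

image = pipe("a prompt").images[0]  # the offloaded pipeline still runs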

File tree

1 file changed: +97 -59 lines changed


src/diffusers/loaders.py

Lines changed: 97 additions & 59 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
 import re
-import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from io import BytesIO
@@ -307,6 +306,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
         # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
         network_alphas = kwargs.pop("network_alphas", None)
+
+        _pipeline = kwargs.pop("_pipeline", None)
+
         is_network_alphas_none = network_alphas is None

         allow_pickle = False
@@ -460,6 +462,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                     load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
                 else:
                     lora.load_state_dict(value_dict)
+
         elif is_custom_diffusion:
             attn_processors = {}
             custom_diffusion_grouped_dict = defaultdict(dict)
@@ -489,19 +492,44 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                         cross_attention_dim=cross_attention_dim,
                     )
                     attn_processors[key].load_state_dict(value_dict)
-
-            self.set_attn_processor(attn_processors)
         else:
             raise ValueError(
                 f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
             )

+        # <Unsafe code
+        # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
+        # Now we remove any existing hooks to
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+        if _pipeline is not None:
+            for _, component in _pipeline.components.items():
+                if isinstance(component, nn.Module):
+                    if hasattr(component, "_hf_hook"):
+                        is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                        is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                        logger.info(
+                            "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                        )
+                        remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+        # only custom diffusion needs to set attn processors
+        if is_custom_diffusion:
+            self.set_attn_processor(attn_processors)
+
         # set lora layers
         for target_module, lora_layer in lora_layers_list:
             target_module.set_lora_layer(lora_layer)

         self.to(dtype=self.dtype, device=self.device)

+        # Offload back.
+        if is_model_cpu_offload:
+            _pipeline.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            _pipeline.enable_sequential_cpu_offload()
+        # Unsafe code />
+
     def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
         is_new_lora_format = all(
             key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
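
The hook handling added in the hunk above boils down to one step: detect which offload flavor is active, strip its hooks, load the weights, then re-enable. A standalone sketch of the detection/removal half, assuming accelerate is installed (the helper name is ours, not part of the diff):

import torch.nn as nn
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module

def detect_and_remove_offload_hooks(pipeline):
    """Return (is_model_cpu_offload, is_sequential_cpu_offload) and strip existing hooks."""
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    if pipeline is not None:
        for _, component in pipeline.components.items():
            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
                is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
                # Sequential offload attaches hooks to submodules, so remove them recursively.
                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
    return is_model_cpu_offload, is_sequential_cpu_offload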
@@ -1060,41 +1088,31 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             kwargs (`dict`, *optional*):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
         """
-        # Remove any existing hooks.
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-        recurive = False
-        for _, component in self.components.items():
-            if isinstance(component, nn.Module):
-                if hasattr(component, "_hf_hook"):
-                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    recurive = is_sequential_cpu_offload
-                    remove_hook_from_module(component, recurse=recurive)
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")

         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

-        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
         self.load_lora_into_unet(
-            state_dict, network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage
+            state_dict,
+            network_alphas=network_alphas,
+            unet=self.unet,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            _pipeline=self,
        )
         self.load_lora_into_text_encoder(
             state_dict,
             network_alphas=network_alphas,
             text_encoder=self.text_encoder,
             lora_scale=self.lora_scale,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            _pipeline=self,
         )

-        # Offload back.
-        if is_model_cpu_offload:
-            self.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            self.enable_sequential_cpu_offload()
-
     @classmethod
     def lora_state_dict(
         cls,
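
The compatibility gate introduced in `load_lora_weights()` above is intentionally cheap; a sketch of the same check on its own (the helper name is hypothetical, the logic mirrors the diff):

def looks_like_lora_state_dict(state_dict) -> bool:
    # Every key of a diffusers- or kohya-style LoRA checkpoint contains "lora";
    # anything else is rejected before the pipeline's offload hooks are touched.
    return all("lora" in key for key in state_dict.keys())

# Usage, mirroring the diff:
# if not looks_like_lora_state_dict(state_dict):
#     raise ValueError("Invalid LoRA checkpoint.")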
@@ -1391,7 +1409,7 @@ def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="
         return new_state_dict

     @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -1433,13 +1451,22 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage
             # Otherwise, we're dealing with the old format. This means the `state_dict` should only
             # contain the module names of the `unet` as its keys WITHOUT any prefix.
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-            warnings.warn(warn_message)
+            logger.warn(warn_message)

-        unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)
+        unet.load_attn_procs(
+            state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
+        )

     @classmethod
     def load_lora_into_text_encoder(
-        cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None
+        cls,
+        state_dict,
+        network_alphas,
+        text_encoder,
+        prefix=None,
+        lora_scale=1.0,
+        low_cpu_mem_usage=None,
+        _pipeline=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1549,11 +1576,15 @@ def load_lora_into_text_encoder(
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )

-            # set correct dtype & device
-            text_encoder_lora_state_dict = {
-                k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
-                for k, v in text_encoder_lora_state_dict.items()
-            }
+            is_pipeline_offloaded = _pipeline is not None and any(
+                isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
+            )
+            if is_pipeline_offloaded and low_cpu_mem_usage:
+                low_cpu_mem_usage = True
+                logger.info(
+                    f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
+                )
+
             if low_cpu_mem_usage:
                 device = next(iter(text_encoder_lora_state_dict.values())).device
                 dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
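
The predicate added above, shown standalone: it only asks whether any pipeline component currently carries an accelerate hook. A sketch with the same logic as the diff (the function name is ours):

import torch

def pipeline_is_offloaded(pipeline) -> bool:
    # True when any nn.Module component has an accelerate `_hf_hook` attached,
    # i.e. model or sequential CPU offloading is currently enabled.
    return pipeline is not None and any(
        isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook")
        for c in pipeline.components.values()
    )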
@@ -1569,8 +1600,33 @@ def load_lora_into_text_encoder(
                     f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
                 )

+            # <Unsafe code
+            # We can be sure that the following works as all we do is change the dtype and device of the text encoder
+            # Now we remove any existing hooks to
+            is_model_cpu_offload = False
+            is_sequential_cpu_offload = False
+            if _pipeline is not None:
+                for _, component in _pipeline.components.items():
+                    if isinstance(component, torch.nn.Module):
+                        if hasattr(component, "_hf_hook"):
+                            is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                            is_sequential_cpu_offload = isinstance(
+                                getattr(component, "_hf_hook"), AlignDevicesHook
+                            )
+                            logger.info(
+                                "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                            )
+                            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
             text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />
+
     @property
     def lora_scale(self) -> float:
         # property function that returns the lora scale which can be set at run time by the pipeline.
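
The "Offload back." epilogue is the counterpart of the hook removal sketched earlier: whichever flavor was detected gets re-enabled once the LoRA parameters are in place. A sketch with a hypothetical helper name, not diffusers API:

def restore_offload(pipeline, is_model_cpu_offload, is_sequential_cpu_offload):
    # Re-apply the offloading mode that was active before the hooks were removed.
    if is_model_cpu_offload:
        pipeline.enable_model_cpu_offload()
    elif is_sequential_cpu_offload:
        pipeline.enable_sequential_cpu_offload()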
@@ -2639,31 +2695,17 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         # it here explicitly to be able to tell that it's coming from an SDXL
         # pipeline.

-        # Remove any existing hooks.
-        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
-            from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
-        else:
-            raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
-
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-        for _, component in self.components.items():
-            if isinstance(component, torch.nn.Module):
-                if hasattr(component, "_hf_hook"):
-                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
             **kwargs,
         )
-        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")

+        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:
             self.load_lora_into_text_encoder(
@@ -2672,6 +2714,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 text_encoder=self.text_encoder,
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
+                _pipeline=self,
             )

         text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
@@ -2682,14 +2725,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 text_encoder=self.text_encoder_2,
                 prefix="text_encoder_2",
                 lora_scale=self.lora_scale,
+                _pipeline=self,
             )

-        # Offload back.
-        if is_model_cpu_offload:
-            self.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            self.enable_sequential_cpu_offload()
-
     @classmethod
     def save_lora_weights(
         self,
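
The SDXL path changed in the last hunks behaves the same way from the caller's side; a hedged usage sketch (the repo IDs below are placeholders, not from the commit):

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.enable_sequential_cpu_offload()

# Both text encoders and the UNet now receive `_pipeline=self` internally, so their
# offload hooks are removed and re-applied around the actual weight loading.
pipe.load_lora_weights("some-user/some-sdxl-lora")  # placeholder repo id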
