
Commit 4304a6d

Commit message: change to a different approach.
1 parent 2bf7fde commit 4304a6d

File tree: 5 files changed (+79 additions, -59 deletions)


src/diffusers/loaders/lora_base.py

Lines changed: 25 additions & 12 deletions
@@ -46,7 +46,7 @@
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
-from ..utils.state_dict_utils import _maybe_populate_state_dict_with_metadata
+from ..utils.state_dict_utils import _load_sft_state_dict_metadata


 if is_transformers_available():
@@ -209,6 +209,7 @@ def _fetch_state_dict(
     subfolder,
     user_agent,
     allow_pickle,
+    metadata=None,
 ):
     model_file = None
     if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -240,13 +241,14 @@ def _fetch_state_dict(
                     user_agent=user_agent,
                 )
                 state_dict = safetensors.torch.load_file(model_file, device="cpu")
-                state_dict = _maybe_populate_state_dict_with_metadata(state_dict, model_file)
+                metadata = _load_sft_state_dict_metadata(model_file)

             except (IOError, safetensors.SafetensorError) as e:
                 if not allow_pickle:
                     raise e
                 # try loading non-safetensors weights
                 model_file = None
+                metadata = None
                 pass

         if model_file is None:
@@ -267,10 +269,11 @@ def _fetch_state_dict(
                 user_agent=user_agent,
             )
             state_dict = load_state_dict(model_file)
+            metadata = None
     else:
         state_dict = pretrained_model_name_or_path_or_dict

-    return state_dict
+    return state_dict, metadata


 def _best_guess_weight_name(
@@ -312,6 +315,11 @@ def _best_guess_weight_name(
     return weight_name


+def _pack_sd_with_prefix(state_dict, prefix):
+    sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
+    return sd_with_prefix
+
+
 def _load_lora_into_text_encoder(
     state_dict,
     network_alphas,
@@ -320,13 +328,17 @@ def _load_lora_into_text_encoder(
     lora_scale=1.0,
     text_encoder_name="text_encoder",
     adapter_name=None,
+    metadata=None,
     _pipeline=None,
     low_cpu_mem_usage=False,
     hotswap: bool = False,
 ):
     if not USE_PEFT_BACKEND:
         raise ValueError("PEFT backend is required for this method.")

+    if network_alphas and metadata:
+        raise ValueError("Both `network_alphas` and `metadata` cannot be specified.")
+
     peft_kwargs = {}
     if low_cpu_mem_usage:
         if not is_peft_version(">=", "0.13.1"):
@@ -353,13 +365,10 @@ def _load_lora_into_text_encoder(
         raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")

     # Load the layers corresponding to text encoder and make necessary adjustments.
-    metadata = None
-    if LORA_ADAPTER_METADATA_KEY in state_dict:
-        metadata = state_dict[LORA_ADAPTER_METADATA_KEY]
     if prefix is not None:
         state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
-        if metadata is not None:
-            state_dict[LORA_ADAPTER_METADATA_KEY] = metadata
+    if metadata is not None:
+        metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}

     if len(state_dict) > 0:
         logger.info(f"Loading {prefix}.")
@@ -387,7 +396,10 @@ def _load_lora_into_text_encoder(
             alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
             network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}

-        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False, prefix=prefix)
+        if metadata is not None:
+            lora_config_kwargs = metadata
+        else:
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False, prefix=prefix)

         if "use_dora" in lora_config_kwargs:
             if lora_config_kwargs["use_dora"]:
@@ -885,8 +897,7 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
     @staticmethod
     def pack_weights(layers, prefix):
         layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
-        layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
-        return layers_state_dict
+        return _pack_sd_with_prefix(layers_weights, prefix)

     @staticmethod
     def write_lora_layers(
@@ -917,7 +928,9 @@ def save_function(weights, filename):
                        for key, value in lora_adapter_metadata.items():
                            if isinstance(value, set):
                                lora_adapter_metadata[key] = list(value)
-                        metadata["lora_adapter_metadata"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
+                        metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
+                            lora_adapter_metadata, indent=2, sort_keys=True
+                        )

                     return safetensors.torch.save_file(weights, filename, metadata=metadata)
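The net effect in lora_base.py: `_fetch_state_dict` now returns a `(state_dict, metadata)` pair, and the adapter config travels as a separate `metadata` dict instead of being stashed inside the state dict under `LORA_ADAPTER_METADATA_KEY`. On the save side, `write_lora_layers` keeps serializing that config as JSON into the safetensors header. Below is a minimal standalone sketch of that save path, assuming `LORA_ADAPTER_METADATA_KEY` resolves to the string `"lora_adapter_metadata"` (as the replaced literal suggests); the file name and tensor are illustrative.

```python
# Standalone sketch of the save path mirrored by write_lora_layers: the LoRA
# adapter config is JSON-encoded into the safetensors header, not into the tensors.
import json

import safetensors.torch
import torch

LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"  # assumed value of the constant

weights = {"unet.lora_A.weight": torch.zeros(4, 8)}  # toy LoRA tensor
lora_adapter_metadata = {"r": 4, "lora_alpha": 4, "use_dora": False}

metadata = {"format": "pt"}
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)

safetensors.torch.save_file(weights, "toy_lora.safetensors", metadata=metadata)
```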

src/diffusers/loaders/lora_pipeline.py

Lines changed: 20 additions & 6 deletions
@@ -37,6 +37,7 @@
     LoraBaseMixin,
     _fetch_state_dict,
     _load_lora_into_text_encoder,
+    _pack_sd_with_prefix,
 )
 from .lora_conversion_utils import (
     _convert_bfl_flux_control_lora_to_diffusers,
@@ -197,7 +198,8 @@ def load_lora_weights(
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        kwargs["return_lora_metadata"] = True
+        state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -208,6 +210,7 @@ def load_lora_weights(
             network_alphas=network_alphas,
             unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -221,6 +224,7 @@ def load_lora_weights(
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
                 _pipeline=self,
+                metadata=metadata,
                 low_cpu_mem_usage=low_cpu_mem_usage,
                 hotswap=hotswap,
             )
@@ -277,6 +281,7 @@ def lora_state_dict(
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
             weight_name (`str`, *optional*, defaults to None):
                 Name of the serialized state dict file.
+            return_lora_metadata: TODO
         """
         # Load the main state dict first which has the LoRA layers for either of
         # UNet and text encoder or both.
@@ -290,6 +295,7 @@ def lora_state_dict(
         weight_name = kwargs.pop("weight_name", None)
         unet_config = kwargs.pop("unet_config", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
@@ -301,7 +307,7 @@ def lora_state_dict(
             "framework": "pytorch",
         }

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -338,7 +344,8 @@ def lora_state_dict(
                 state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
             state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

-        return state_dict, network_alphas
+        out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas)
+        return out

     @classmethod
     def load_lora_into_unet(
@@ -347,6 +354,7 @@ def load_lora_into_unet(
         network_alphas,
         unet,
         adapter_name=None,
+        metadata=None,
         _pipeline=None,
         low_cpu_mem_usage=False,
         hotswap: bool = False,
@@ -391,6 +399,7 @@ def load_lora_into_unet(
             prefix=cls.unet_name,
             network_alphas=network_alphas,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -405,6 +414,7 @@ def load_lora_into_text_encoder(
         prefix=None,
         lora_scale=1.0,
         adapter_name=None,
+        metadata=None,
         _pipeline=None,
         low_cpu_mem_usage=False,
         hotswap: bool = False,
@@ -430,6 +440,7 @@ def load_lora_into_text_encoder(
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            metadata: TODO
             low_cpu_mem_usage (`bool`, *optional*):
                 Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                 weights.
@@ -444,6 +455,7 @@ def load_lora_into_text_encoder(
             prefix=prefix,
             text_encoder_name=cls.text_encoder_name,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -500,11 +512,13 @@ def save_lora_weights(
         if text_encoder_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))

-        if unet_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(cls.pack_weights(unet_lora_adapter_metadata, cls.unet_name))
+        if unet_lora_adapter_metadata:
+            lora_adapter_metadata.update(_pack_sd_with_prefix(unet_lora_adapter_metadata, cls.unet_name))

         if text_encoder_lora_adapter_metadata:
-            lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name))
+            lora_adapter_metadata.update(
+                _pack_sd_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
+            )

         # Save the model
         cls.write_lora_layers(
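For callers, the visible change in lora_pipeline.py is the opt-in `return_lora_metadata` flag: `lora_state_dict` keeps its two-tuple return by default and adds a third element when the flag is set, while `load_lora_weights` enables it internally and threads `metadata` down to the UNet and text-encoder loaders. A hedged usage sketch follows; the LoRA repo id is a placeholder, not from this commit.

```python
# Usage sketch: opting into the third return value added by this commit.
# "some-user/some-lora" is a placeholder repo id.
from diffusers import StableDiffusionPipeline

# Default call: the (state_dict, network_alphas) signature is unchanged.
state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict("some-user/some-lora")

# With return_lora_metadata=True, a third element carries the adapter config
# read from the safetensors header (None for non-safetensors checkpoints).
state_dict, network_alphas, metadata = StableDiffusionPipeline.lora_state_dict(
    "some-user/some-lora", return_lora_metadata=True
)
```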

src/diffusers/loaders/peft.py

Lines changed: 18 additions & 19 deletions
@@ -185,13 +185,11 @@ def load_lora_adapter(
                 Note that hotswapping adapters of the text encoder is not yet supported. There are some further
                 limitations to this technique, which are documented here:
                 https://huggingface.co/docs/peft/main/en/package_reference/hotswap
-
+            metadata: TODO
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer

-        from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY
-
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -205,19 +203,17 @@ def load_lora_adapter(
         network_alphas = kwargs.pop("network_alphas", None)
         _pipeline = kwargs.pop("_pipeline", None)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
+        metadata = kwargs.pop("metadata", None)
         allow_pickle = False

         if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
             raise ValueError(
                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
             )

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -230,17 +226,17 @@ def load_lora_adapter(
             subfolder=subfolder,
             user_agent=user_agent,
             allow_pickle=allow_pickle,
+            metadata=metadata,
         )
-        metadata = None
-        if LORA_ADAPTER_METADATA_KEY in state_dict:
-            metadata = state_dict[LORA_ADAPTER_METADATA_KEY]
         if network_alphas is not None and prefix is None:
             raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
+        if network_alphas and metadata:
+            raise ValueError("Both `network_alphas` and `metadata` cannot be specified.")

         if prefix is not None:
             state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
-            if metadata is not None:
-                state_dict[LORA_ADAPTER_METADATA_KEY] = metadata
+        if metadata is not None:
+            metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}

         if len(state_dict) > 0:
             if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
@@ -275,12 +271,15 @@ def load_lora_adapter(
                     k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
                 }

-            lora_config_kwargs = get_peft_kwargs(
-                rank,
-                network_alpha_dict=network_alphas,
-                peft_state_dict=state_dict,
-                prefix=prefix,
-            )
+            if metadata is not None:
+                lora_config_kwargs = metadata
+            else:
+                lora_config_kwargs = get_peft_kwargs(
+                    rank,
+                    network_alpha_dict=network_alphas,
+                    peft_state_dict=state_dict,
+                    prefix=prefix,
+                )
             _maybe_raise_error_for_ambiguity(lora_config_kwargs)

             if "use_dora" in lora_config_kwargs:

src/diffusers/utils/state_dict_utils.py

Lines changed: 8 additions & 13 deletions
@@ -350,21 +350,16 @@ def state_dict_all_zero(state_dict, filter_str=None):
     return all(torch.all(param == 0).item() for param in state_dict.values())


-def _maybe_populate_state_dict_with_metadata(state_dict, model_file):
-    if not model_file.endswith(".safetensors"):
-        return state_dict
-
+def _load_sft_state_dict_metadata(model_file: str):
     import safetensors.torch

     from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY

-    metadata_key = LORA_ADAPTER_METADATA_KEY
+    metadata = None
     with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
-        if hasattr(f, "metadata"):
-            metadata = f.metadata()
-        if metadata is not None:
-            metadata_keys = list(metadata.keys())
-            if not (len(metadata_keys) == 1 and metadata_keys[0] == "format"):
-                peft_metadata = {k: v for k, v in metadata.items() if k != "format"}
-                state_dict["lora_adapter_metadata"] = json.loads(peft_metadata[metadata_key])
-    return state_dict
+        metadata = f.metadata()
+    if metadata is not None:
+        metadata_keys = list(metadata.keys())
+        if not (len(metadata_keys) == 1 and metadata_keys[0] == "format"):
+            metadata = json.loads(metadata[LORA_ADAPTER_METADATA_KEY])
+    return metadata
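The helper in state_dict_utils.py no longer mutates the state dict at all; it reads the safetensors header and returns the decoded adapter config. A standalone sketch of the same idea using only the public safetensors API follows; the file name and the fallback-to-None behaviour for metadata-free files are illustrative choices, not guaranteed to match the helper exactly.

```python
# Reading a LoRA adapter config back out of a safetensors header.
import json

from safetensors import safe_open

LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"  # assumed value of the constant


def read_lora_metadata(model_file: str):
    with safe_open(model_file, framework="pt", device="cpu") as f:
        header = f.metadata()  # Dict[str, str] stored in the file header, or None
    # Files written without adapter metadata typically carry only {"format": "pt"}.
    if not header or LORA_ADAPTER_METADATA_KEY not in header:
        return None
    return json.loads(header[LORA_ADAPTER_METADATA_KEY])


print(read_lora_metadata("toy_lora.safetensors"))
```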
