Skip to content

Commit ded2fd6

Browse files
committed
automatically save metadata in save_lora_adapter.
1 parent 7f59ca0 commit ded2fd6

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

src/diffusers/loaders/peft.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,8 @@ def load_lora_adapter(
193193
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
194194
from peft.tuners.tuners_utils import BaseTunerLayer
195195

196+
from .lora_base import LORA_ADAPTER_METADATA_KEY
197+
196198
cache_dir = kwargs.pop("cache_dir", None)
197199
force_download = kwargs.pop("force_download", False)
198200
proxies = kwargs.pop("proxies", None)
@@ -236,11 +238,11 @@ def load_lora_adapter(
236238
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
237239

238240
if prefix is not None:
239-
metadata = state_dict.pop("lora_adapter_metadata", None)
241+
metadata = state_dict.pop(LORA_ADAPTER_METADATA_KEY, None)
240242
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
241243

242244
if metadata is not None:
243-
state_dict["lora_adapter_metadata"] = metadata
245+
state_dict[LORA_ADAPTER_METADATA_KEY] = metadata
244246

245247
if len(state_dict) > 0:
246248
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
@@ -464,23 +466,19 @@ def save_lora_adapter(
464466
safe_serialization (`bool`, *optional*, defaults to `True`):
465467
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
466468
weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
467-
lora_adapter_metadata: TODO
468469
"""
469470
from peft.utils import get_peft_model_state_dict
470471

471-
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
472-
473-
if lora_adapter_metadata is not None and not safe_serialization:
474-
raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
475-
if not isinstance(lora_adapter_metadata, dict):
476-
raise ValueError("`lora_adapter_metadata` must be of type `dict`.")
472+
from .lora_base import LORA_ADAPTER_METADATA_KEY, LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
477473

478474
if adapter_name is None:
479475
adapter_name = get_adapter_name(self)
480476

481477
if adapter_name not in getattr(self, "peft_config", {}):
482478
raise ValueError(f"Adapter name {adapter_name} not found in the model.")
483479

480+
lora_adapter_metadata = self.peft_config[adapter_name]
481+
484482
lora_layers_to_save = get_peft_model_state_dict(
485483
self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
486484
)
@@ -497,7 +495,7 @@ def save_function(weights, filename):
497495
for key, value in lora_adapter_metadata.items():
498496
if isinstance(value, set):
499497
lora_adapter_metadata[key] = list(value)
500-
metadata["lora_adapter_metadata"] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
498+
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
501499

502500
return safetensors.torch.save_file(weights, filename, metadata=metadata)
503501

@@ -512,7 +510,6 @@ def save_function(weights, filename):
512510
else:
513511
weight_name = LORA_WEIGHT_NAME
514512

515-
# TODO: we could consider saving the `peft_config` as well.
516513
save_path = Path(save_directory, weight_name).as_posix()
517514
save_function(lora_layers_to_save, save_path)
518515
logger.info(f"Model weights saved in {save_path}")

0 commit comments

Comments
 (0)