Skip to content

Commit 7ec4ef4

Browse files
committed
smol updates
1 parent 42bb6bc commit 7ec4ef4

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

src/diffusers/loaders/peft.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,11 +239,11 @@ def load_lora_adapter(
239239
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
240240

241241
if prefix is not None:
242-
metadata = state_dict.pop("lora_metadata", None)
242+
metadata = state_dict.pop("lora_adapter_metadata", None)
243243
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
244244

245245
if metadata is not None:
246-
state_dict["lora_metadata"] = metadata
246+
state_dict["lora_adapter_metadata"] = metadata
247247

248248
if len(state_dict) > 0:
249249
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:

src/diffusers/utils/peft_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,9 @@ def get_peft_kwargs(
151151
rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, prefix=None, load_with_metadata=False
152152
):
153153
if load_with_metadata:
154-
if "lora_metadata" not in peft_state_dict:
155-
raise ValueError("Couldn't find 'lora_metadata' key in the `peft_state_dict`.")
156-
metadata = peft_state_dict["lora_metadata"]
154+
if "lora_adapter_metadata" not in peft_state_dict:
155+
raise ValueError("Couldn't find 'lora_adapter_metadata' key in the `peft_state_dict`.")
156+
metadata = peft_state_dict["lora_adapter_metadata"]
157157
if prefix is not None:
158158
metadata = {k.replace(f"{prefix}.", ""): v for k, v in metadata.items()}
159159
return metadata

src/diffusers/utils/state_dict_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def _maybe_populate_state_dict_with_metadata(state_dict, model_file, metadata_ke
360360
metadata_keys = list(metadata.keys())
361361
if not (len(metadata_keys) == 1 and metadata_keys[0] == "format"):
362362
peft_metadata = {k: v for k, v in metadata.items() if k != "format"}
363-
state_dict["lora_metadata"] = json.loads(peft_metadata[metadata_key])
363+
state_dict["lora_adapter_metadata"] = json.loads(peft_metadata[metadata_key])
364364
else:
365365
raise ValueError("Metadata couldn't be parsed from the safetensors file.")
366366
return state_dict

0 commit comments

Comments
 (0)