Skip to content

Commit 1348463

Browse files
committed
simplify _load_sft_state_dict_metadata
1 parent c4bd1c7 commit 1348463

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

src/diffusers/utils/state_dict_utils.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -355,11 +355,9 @@ def _load_sft_state_dict_metadata(model_file: str):
355355

356356
from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY
357357

358-
metadata = None
359358
with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
360-
metadata = f.metadata()
361-
if metadata is not None:
362-
metadata_keys = list(metadata.keys())
363-
if not (len(metadata_keys) == 1 and metadata_keys[0] == "format"):
364-
metadata = json.loads(metadata[LORA_ADAPTER_METADATA_KEY])
365-
return metadata
359+
metadata = f.metadata() or {}
360+
361+
metadata.pop("format", None)
362+
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
363+
return json.loads(raw) if raw else None

0 commit comments

Comments
 (0)