Commit f4d4179 ("fixes")
Parent: 99fe09c

6 files changed: +32, -66 lines changed

examples/community/ip_adapter_face_id.py

Lines changed: 1 addition & 4 deletions
@@ -282,10 +282,7 @@ def load_ip_adapter_face_id(self, pretrained_model_name_or_path_or_dict, weight_
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         model_file = _get_model_file(
             pretrained_model_name_or_path_or_dict,
             weights_name=weight_name,

src/diffusers/loaders/ip_adapter.py

Lines changed: 3 additions & 12 deletions
@@ -159,10 +159,7 @@ def load_ip_adapter(
                 " `low_cpu_mem_usage=False`."
             )

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         state_dicts = []
         for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
             pretrained_model_name_or_path_or_dict, weight_name, subfolder
@@ -465,10 +462,7 @@ def load_ip_adapter(
                 " `low_cpu_mem_usage=False`."
             )

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         state_dicts = []
         for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
             pretrained_model_name_or_path_or_dict, weight_name, subfolder
@@ -750,10 +744,7 @@ def load_ip_adapter(
                 " `low_cpu_mem_usage=False`."
             )

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             model_file = _get_model_file(

src/diffusers/loaders/lora_base.py

Lines changed: 1 addition & 1 deletion
@@ -398,7 +398,7 @@ def _load_lora_into_text_encoder(
         if metadata is not None:
             lora_config_kwargs = metadata
         else:
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False, prefix=prefix)
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)

         if "use_dora" in lora_config_kwargs:
             if lora_config_kwargs["use_dora"]:

src/diffusers/loaders/lora_pipeline.py

Lines changed: 25 additions & 41 deletions
@@ -302,10 +302,7 @@ def lora_state_dict(
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

         state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
@@ -785,10 +782,7 @@ def lora_state_dict(
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

         state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
@@ -1954,7 +1948,7 @@ def lora_state_dict(
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+            return_lora_metadata: TODO
         """
         # Load the main state dict first which has the LoRA layers for either of
         # transformer and text encoder or both.
@@ -1967,18 +1961,16 @@ def lora_state_dict(
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -2032,10 +2024,12 @@ def lora_state_dict(
                 f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
             )

+        outputs = [state_dict]
         if return_alphas:
-            return state_dict, network_alphas
-        else:
-            return state_dict
+            outputs.append(network_alphas)
+        if return_lora_metadata:
+            outputs.append(metadata)
+        return tuple(outputs)

     def load_lora_weights(
         self,
@@ -2084,7 +2078,8 @@ def load_lora_weights(
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict, network_alphas = self.lora_state_dict(
+        kwargs["return_lora_metadata"] = True
+        state_dict, network_alphas, metadata = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
         )

@@ -2135,6 +2130,7 @@ def load_lora_weights(
                 network_alphas=network_alphas,
                 transformer=transformer,
                 adapter_name=adapter_name,
+                metadata=metadata,
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
                 hotswap=hotswap,
@@ -2154,6 +2150,7 @@ def load_lora_weights(
                 prefix=self.text_encoder_name,
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
+                metadata=metadata,
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
                 hotswap=hotswap,
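
For reference, a minimal sketch (not part of this commit) of how a caller can exercise the new `return_lora_metadata` flag added in the hunks above. The `return_alphas` signature suggests these hunks belong to the Flux LoRA loader mixin, so the sketch assumes `FluxPipeline`; the LoRA repo id and weight file name are hypothetical placeholders.

```python
from diffusers import FluxPipeline

# `lora_state_dict` is a classmethod, so no pipeline weights need to be loaded.
# With both flags set, it now returns the 3-tuple consumed by `load_lora_weights` above.
state_dict, network_alphas, metadata = FluxPipeline.lora_state_dict(
    "some-user/some-flux-lora",                      # hypothetical Hub repo
    weight_name="pytorch_lora_weights.safetensors",  # hypothetical file name
    return_alphas=True,
    return_lora_metadata=True,
)

# `metadata` is whatever `_fetch_state_dict` recovered from the checkpoint
# (for example, a LoRA config stored alongside the weights), if any.
print(len(state_dict), network_alphas is not None, metadata is not None)
```
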
@@ -3661,10 +3658,7 @@ def lora_state_dict(
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

         state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
@@ -4339,10 +4333,7 @@ def lora_state_dict(
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

         state_dict = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
@@ -4679,10 +4670,7 @@ def lora_state_dict(
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

         state_dict = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
@@ -5001,7 +4989,7 @@ def lora_state_dict(
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+            return_lora_metadata: TODO
         """
         # Load the main state dict first which has the LoRA layers for either of
         # transformer and text encoder or both.
@@ -5014,18 +5002,16 @@ def lora_state_dict(
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -5050,7 +5036,8 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

-        return state_dict
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out

     @classmethod
     def _maybe_expand_t2v_lora_for_i2v(
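
Similarly, a hedged sketch for the loader changed just above (the `_maybe_expand_t2v_lora_for_i2v` helper suggests the Wan LoRA mixin, so `WanPipeline` is assumed); repo id and file name are again placeholders. By default the call still returns a bare state dict and only yields a `(state_dict, metadata)` pair when the new flag is set.

```python
from diffusers import WanPipeline

# Default behaviour is unchanged: a plain state dict.
state_dict = WanPipeline.lora_state_dict(
    "some-user/some-wan-lora",                       # hypothetical Hub repo
    weight_name="pytorch_lora_weights.safetensors",  # hypothetical file name
)

# Opting in via the new flag yields a (state_dict, metadata) tuple instead.
state_dict, metadata = WanPipeline.lora_state_dict(
    "some-user/some-wan-lora",
    weight_name="pytorch_lora_weights.safetensors",
    return_lora_metadata=True,
)
```
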
@@ -5746,10 +5733,7 @@ def lora_state_dict(
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

         state_dict = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,

src/diffusers/loaders/peft.py

Lines changed: 1 addition & 4 deletions
@@ -275,10 +275,7 @@ def load_lora_adapter(
                 lora_config_kwargs = metadata
             else:
                 lora_config_kwargs = get_peft_kwargs(
-                    rank,
-                    network_alpha_dict=network_alphas,
-                    peft_state_dict=state_dict,
-                    prefix=prefix,
+                    rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
                 )
             _maybe_raise_error_for_ambiguity(lora_config_kwargs)

src/diffusers/loaders/unet.py

Lines changed: 1 addition & 4 deletions
@@ -155,10 +155,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

         model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
