@@ -302,10 +302,7 @@ def lora_state_dict(
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

         state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
@@ -785,10 +782,7 @@ def lora_state_dict(
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

         state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
@@ -1954,7 +1948,7 @@ def lora_state_dict(
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+            return_lora_metadata (`bool`, *optional*, defaults to `False`): When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # transformer and text encoder or both.
@@ -1967,18 +1961,16 @@ def lora_state_dict(
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -2032,10 +2024,12 @@ def lora_state_dict(
20322024 f"The alpha key ({ k } ) seems to be incorrect. If you think this error is unexpected, please open as issue."
20332025 )
20342026
2027+ outputs = [state_dict ]
20352028 if return_alphas :
2036- return state_dict , network_alphas
2037- else :
2038- return state_dict
2029+ outputs .append (network_alphas )
2030+ if return_lora_metadata :
2031+ outputs .append (metadata )
2032+ return tuple (outputs )
20392033
20402034 def load_lora_weights (
20412035 self ,
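Editor's note: a minimal usage sketch of the return contract introduced in the hunk above, assuming it lives on the Flux LoRA mixin (suggested by the `return_alphas` handling); the repository id below is a placeholder, not taken from this commit.

from diffusers import FluxPipeline

# With both flags set, lora_state_dict now returns a 3-tuple (hypothetical repo id).
state_dict, network_alphas, metadata = FluxPipeline.lora_state_dict(
    "some-user/some-flux-lora", return_alphas=True, return_lora_metadata=True
)

# Existing callers that only pass return_alphas=True still get the old 2-tuple.
state_dict, network_alphas = FluxPipeline.lora_state_dict(
    "some-user/some-flux-lora", return_alphas=True
)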
@@ -2084,7 +2078,8 @@ def load_lora_weights(
         pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict, network_alphas = self.lora_state_dict(
+        kwargs["return_lora_metadata"] = True
+        state_dict, network_alphas, metadata = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
         )

@@ -2135,6 +2130,7 @@ def load_lora_weights(
             network_alphas=network_alphas,
             transformer=transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -2154,6 +2150,7 @@ def load_lora_weights(
                 prefix=self.text_encoder_name,
                 lora_scale=self.lora_scale,
                 adapter_name=adapter_name,
+                metadata=metadata,
                 _pipeline=self,
                 low_cpu_mem_usage=low_cpu_mem_usage,
                 hotswap=hotswap,
@@ -3661,10 +3658,7 @@ def lora_state_dict(
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

         state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
@@ -4339,10 +4333,7 @@ def lora_state_dict(
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

         state_dict = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
@@ -4679,10 +4670,7 @@ def lora_state_dict(
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

         state_dict = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
@@ -5001,7 +4989,7 @@ def lora_state_dict(
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+            return_lora_metadata (`bool`, *optional*, defaults to `False`): When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # transformer and text encoder or both.
@@ -5014,18 +5002,16 @@ def lora_state_dict(
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -5050,7 +5036,8 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

-        return state_dict
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out

     @classmethod
     def _maybe_expand_t2v_lora_for_i2v(
@@ -5746,10 +5733,7 @@ def lora_state_dict(
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

         state_dict = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
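Editor's note: the mixins whose `lora_state_dict` takes no `return_alphas` flag (the variant that now ends with `out = (state_dict, metadata) if return_lora_metadata else state_dict`) switch between a plain dict and a 2-tuple. A minimal sketch under the assumption that one of them is the Wan pipeline mixin; the pipeline class and repository id are illustrative only.

from diffusers import WanPipeline

# Default behaviour is unchanged: only the state dict comes back.
state_dict = WanPipeline.lora_state_dict("some-user/some-wan-lora")

# Opting in returns (state_dict, metadata) instead.
state_dict, metadata = WanPipeline.lora_state_dict(
    "some-user/some-wan-lora", return_lora_metadata=True
)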