@@ -1612,6 +1612,7 @@ def lora_state_dict(
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+            return_lora_metadata (`bool`, *optional*, defaults to `False`): Whether to additionally return the metadata stored with the LoRA state dict.
 
         """
         # Load the main state dict first which has the LoRA layers for either of
@@ -1625,18 +1626,16 @@ def lora_state_dict(
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)
 
         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True
 
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
 
-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -1657,7 +1656,8 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
-        return state_dict
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out
 
     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
     def load_lora_weights(
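The change above makes the metadata return opt-in: by default `lora_state_dict` still returns only the state dict, and passing `return_lora_metadata=True` yields a `(state_dict, metadata)` tuple instead. A minimal usage sketch, assuming a pipeline class that mixes in this loader (the class name and repo id below are illustrative, not taken from this diff):

    # Hedged sketch: opt into the metadata return. Class name and repo id are placeholders.
    from diffusers import CogVideoXPipeline

    state_dict, metadata = CogVideoXPipeline.lora_state_dict(
        "some-user/some-lora",  # hypothetical Hub repo id
        return_lora_metadata=True,
    )
    # Without the flag, the call behaves exactly as before and returns only the state dict.
    state_dict_only = CogVideoXPipeline.lora_state_dict("some-user/some-lora")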
@@ -1702,7 +1702,8 @@ def load_lora_weights(
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -1712,6 +1713,7 @@ def load_lora_weights(
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
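Inside `load_lora_weights` the flag is always set (`kwargs["return_lora_metadata"] = True`) so the metadata can be forwarded as `metadata=` to `load_lora_into_transformer`; the public behavior of `load_lora_weights` itself is unchanged. Callers that use `lora_state_dict` directly and want the metadata, but also need to tolerate loaders or releases that ignore the flag, can unpack defensively. A sketch under that assumption (function and variable names are illustrative):

    # Hedged sketch: tolerate both return shapes of lora_state_dict().
    def fetch_lora_with_metadata(pipeline_cls, repo_id):
        out = pipeline_cls.lora_state_dict(repo_id, return_lora_metadata=True)
        # A tuple is returned only when the flag is honored; otherwise a plain dict comes back.
        if isinstance(out, tuple):
            state_dict, metadata = out
        else:
            state_dict, metadata = out, None
        return state_dict, metadata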
@@ -3058,7 +3060,8 @@ def load_lora_weights(
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -3068,6 +3071,7 @@ def load_lora_weights(
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -3391,7 +3395,8 @@ def load_lora_weights(
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -3401,6 +3406,7 @@ def load_lora_weights(
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -3635,7 +3641,7 @@ def lora_state_dict(
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+            return_lora_metadata (`bool`, *optional*, defaults to `False`): Whether to additionally return the metadata stored with the LoRA state dict.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # transformer and text encoder or both.
@@ -3648,6 +3654,7 @@ def lora_state_dict(
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)
 
         allow_pickle = False
         if use_safetensors is None:
@@ -3659,7 +3666,7 @@ def lora_state_dict(
             "framework": "pytorch",
         }
 
-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -3684,7 +3691,8 @@ def lora_state_dict(
         if is_non_diffusers_format:
             state_dict = _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict)
 
-        return state_dict
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out
 
     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
     def load_lora_weights(
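The LTX-Video mixin applies the same pattern after the non-diffusers key conversion, so whatever metadata `_fetch_state_dict` recovered from the original file is returned as-is; depending on how the checkpoint was saved, that metadata may be `None`. A small guard for that case, assuming the LTX pipeline class (class name and repo id are placeholders):

    # Hedged sketch: the metadata may be None when the file carries none, so guard before reading it.
    from diffusers import LTXPipeline

    state_dict, metadata = LTXPipeline.lora_state_dict(
        "some-user/some-ltx-lora",  # hypothetical Hub repo id
        return_lora_metadata=True,
    )
    if metadata:
        print("adapter metadata keys:", sorted(metadata))
    else:
        print("no adapter metadata stored with this checkpoint")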
@@ -3729,7 +3737,8 @@ def load_lora_weights(
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -3739,6 +3748,7 @@ def load_lora_weights(
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -4064,7 +4074,8 @@ def load_lora_weights(
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -4074,6 +4085,7 @@ def load_lora_weights(
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -4402,7 +4414,8 @@ def load_lora_weights(
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -4412,6 +4425,7 @@ def load_lora_weights(
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -4741,7 +4755,8 @@ def load_lora_weights(
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -4751,6 +4766,7 @@ def load_lora_weights(
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -5375,6 +5391,7 @@ def lora_state_dict(
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+            return_lora_metadata (`bool`, *optional*, defaults to `False`): Whether to additionally return the metadata stored with the LoRA state dict.
 
         """
         # Load the main state dict first which has the LoRA layers for either of
@@ -5388,18 +5405,16 @@ def lora_state_dict(
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)
 
         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True
 
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
 
-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -5420,7 +5435,8 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
-        return state_dict
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out
 
     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
     def load_lora_weights(
@@ -5465,7 +5481,8 @@ def load_lora_weights(
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -5475,6 +5492,7 @@ def load_lora_weights(
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -5803,7 +5821,8 @@ def load_lora_weights(
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -5813,6 +5832,7 @@ def load_lora_weights(
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,