Skip to content

Commit 46f4726

Browse files
committed
fixes
1 parent f4d4179 commit 46f4726

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4314,7 +4314,7 @@ def lora_state_dict(
43144314
allowed by Git.
43154315
subfolder (`str`, *optional*, defaults to `""`):
43164316
The subfolder location of a model file within a larger model repository on the Hub or locally.
4317-
4317+
return_lora_metadata (`bool`, *optional*, defaults to `False`):
    When enabled, additionally returns the LoRA adapter metadata, typically found in the state dict.
43184318
"""
43194319
# Load the main state dict first which has the LoRA layers for either of
43204320
# transformer and text encoder or both.
@@ -4327,6 +4327,7 @@ def lora_state_dict(
43274327
subfolder = kwargs.pop("subfolder", None)
43284328
weight_name = kwargs.pop("weight_name", None)
43294329
use_safetensors = kwargs.pop("use_safetensors", None)
4330+
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
43304331

43314332
allow_pickle = False
43324333
if use_safetensors is None:
@@ -4335,7 +4336,7 @@ def lora_state_dict(
43354336

43364337
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
43374338

4338-
state_dict = _fetch_state_dict(
4339+
state_dict, metadata = _fetch_state_dict(
43394340
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
43404341
weight_name=weight_name,
43414342
use_safetensors=use_safetensors,
@@ -4360,7 +4361,8 @@ def lora_state_dict(
43604361
if is_original_hunyuan_video:
43614362
state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
43624363

4363-
return state_dict
4364+
out = (state_dict, metadata) if return_lora_metadata else state_dict
4365+
return out
43644366

43654367
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
43664368
def load_lora_weights(
@@ -4651,7 +4653,7 @@ def lora_state_dict(
46514653
allowed by Git.
46524654
subfolder (`str`, *optional*, defaults to `""`):
46534655
The subfolder location of a model file within a larger model repository on the Hub or locally.
4654-
4656+
return_lora_metadata (`bool`, *optional*, defaults to `False`):
    When enabled, additionally returns the LoRA adapter metadata, typically found in the state dict.
46554657
"""
46564658
# Load the main state dict first which has the LoRA layers for either of
46574659
# transformer and text encoder or both.
@@ -4664,6 +4666,7 @@ def lora_state_dict(
46644666
subfolder = kwargs.pop("subfolder", None)
46654667
weight_name = kwargs.pop("weight_name", None)
46664668
use_safetensors = kwargs.pop("use_safetensors", None)
4669+
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
46674670

46684671
allow_pickle = False
46694672
if use_safetensors is None:
@@ -4672,7 +4675,7 @@ def lora_state_dict(
46724675

46734676
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
46744677

4675-
state_dict = _fetch_state_dict(
4678+
state_dict, metadata = _fetch_state_dict(
46764679
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
46774680
weight_name=weight_name,
46784681
use_safetensors=use_safetensors,
@@ -4698,7 +4701,8 @@ def lora_state_dict(
46984701
if non_diffusers:
46994702
state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
47004703

4701-
return state_dict
4704+
out = (state_dict, metadata) if return_lora_metadata else state_dict
4705+
return out
47024706

47034707
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
47044708
def load_lora_weights(
@@ -5714,7 +5718,7 @@ def lora_state_dict(
57145718
allowed by Git.
57155719
subfolder (`str`, *optional*, defaults to `""`):
57165720
The subfolder location of a model file within a larger model repository on the Hub or locally.
5717-
5721+
return_lora_metadata (`bool`, *optional*, defaults to `False`):
    When enabled, additionally returns the LoRA adapter metadata, typically found in the state dict.
57185722
"""
57195723
# Load the main state dict first which has the LoRA layers for either of
57205724
# transformer and text encoder or both.
@@ -5727,6 +5731,7 @@ def lora_state_dict(
57275731
subfolder = kwargs.pop("subfolder", None)
57285732
weight_name = kwargs.pop("weight_name", None)
57295733
use_safetensors = kwargs.pop("use_safetensors", None)
5734+
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
57305735

57315736
allow_pickle = False
57325737
if use_safetensors is None:
@@ -5735,7 +5740,7 @@ def lora_state_dict(
57355740

57365741
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
57375742

5738-
state_dict = _fetch_state_dict(
5743+
state_dict, metadata = _fetch_state_dict(
57395744
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
57405745
weight_name=weight_name,
57415746
use_safetensors=use_safetensors,
@@ -5760,7 +5765,8 @@ def lora_state_dict(
57605765
if is_non_diffusers_format:
57615766
state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
57625767

5763-
return state_dict
5768+
out = (state_dict, metadata) if return_lora_metadata else state_dict
5769+
return out
57645770

57655771
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
57665772
def load_lora_weights(

0 commit comments

Comments
 (0)