Skip to content

Commit 5a2a023

Browse files
committed
update
1 parent 40f5c97 commit 5a2a023

File tree

1 file changed

+44
-24
lines changed

1 file changed

+44
-24
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1612,6 +1612,7 @@ def lora_state_dict(
16121612
allowed by Git.
16131613
subfolder (`str`, *optional*, defaults to `""`):
16141614
The subfolder location of a model file within a larger model repository on the Hub or locally.
1615+
return_lora_metadata: TODO
16151616
16161617
"""
16171618
# Load the main state dict first which has the LoRA layers for either of
@@ -1625,18 +1626,16 @@ def lora_state_dict(
16251626
subfolder = kwargs.pop("subfolder", None)
16261627
weight_name = kwargs.pop("weight_name", None)
16271628
use_safetensors = kwargs.pop("use_safetensors", None)
1629+
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
16281630

16291631
allow_pickle = False
16301632
if use_safetensors is None:
16311633
use_safetensors = True
16321634
allow_pickle = True
16331635

1634-
user_agent = {
1635-
"file_type": "attn_procs_weights",
1636-
"framework": "pytorch",
1637-
}
1636+
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
16381637

1639-
state_dict = _fetch_state_dict(
1638+
state_dict, metadata = _fetch_state_dict(
16401639
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
16411640
weight_name=weight_name,
16421641
use_safetensors=use_safetensors,
@@ -1657,7 +1656,8 @@ def lora_state_dict(
16571656
logger.warning(warn_msg)
16581657
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
16591658

1660-
return state_dict
1659+
out = (state_dict, metadata) if return_lora_metadata else state_dict
1660+
return out
16611661

16621662
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
16631663
def load_lora_weights(
@@ -1702,7 +1702,8 @@ def load_lora_weights(
17021702
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
17031703

17041704
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
1705-
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
1705+
kwargs["return_lora_metadata"] = True
1706+
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
17061707

17071708
is_correct_format = all("lora" in key for key in state_dict.keys())
17081709
if not is_correct_format:
@@ -1712,6 +1713,7 @@ def load_lora_weights(
17121713
state_dict,
17131714
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
17141715
adapter_name=adapter_name,
1716+
metadata=metadata,
17151717
_pipeline=self,
17161718
low_cpu_mem_usage=low_cpu_mem_usage,
17171719
hotswap=hotswap,
@@ -3058,7 +3060,8 @@ def load_lora_weights(
30583060
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
30593061

30603062
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
3061-
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
3063+
kwargs["return_lora_metadata"] = True
3064+
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
30623065

30633066
is_correct_format = all("lora" in key for key in state_dict.keys())
30643067
if not is_correct_format:
@@ -3068,6 +3071,7 @@ def load_lora_weights(
30683071
state_dict,
30693072
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
30703073
adapter_name=adapter_name,
3074+
metadata=metadata,
30713075
_pipeline=self,
30723076
low_cpu_mem_usage=low_cpu_mem_usage,
30733077
hotswap=hotswap,
@@ -3391,7 +3395,8 @@ def load_lora_weights(
33913395
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
33923396

33933397
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
3394-
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
3398+
kwargs["return_lora_metadata"] = True
3399+
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
33953400

33963401
is_correct_format = all("lora" in key for key in state_dict.keys())
33973402
if not is_correct_format:
@@ -3401,6 +3406,7 @@ def load_lora_weights(
34013406
state_dict,
34023407
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
34033408
adapter_name=adapter_name,
3409+
metadata=metadata,
34043410
_pipeline=self,
34053411
low_cpu_mem_usage=low_cpu_mem_usage,
34063412
hotswap=hotswap,
@@ -3635,7 +3641,7 @@ def lora_state_dict(
36353641
allowed by Git.
36363642
subfolder (`str`, *optional*, defaults to `""`):
36373643
The subfolder location of a model file within a larger model repository on the Hub or locally.
3638-
3644+
return_lora_metadata: TODO
36393645
"""
36403646
# Load the main state dict first which has the LoRA layers for either of
36413647
# transformer and text encoder or both.
@@ -3648,6 +3654,7 @@ def lora_state_dict(
36483654
subfolder = kwargs.pop("subfolder", None)
36493655
weight_name = kwargs.pop("weight_name", None)
36503656
use_safetensors = kwargs.pop("use_safetensors", None)
3657+
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
36513658

36523659
allow_pickle = False
36533660
if use_safetensors is None:
@@ -3659,7 +3666,7 @@ def lora_state_dict(
36593666
"framework": "pytorch",
36603667
}
36613668

3662-
state_dict = _fetch_state_dict(
3669+
state_dict, metadata = _fetch_state_dict(
36633670
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
36643671
weight_name=weight_name,
36653672
use_safetensors=use_safetensors,
@@ -3684,7 +3691,8 @@ def lora_state_dict(
36843691
if is_non_diffusers_format:
36853692
state_dict = _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict)
36863693

3687-
return state_dict
3694+
out = (state_dict, metadata) if return_lora_metadata else state_dict
3695+
return out
36883696

36893697
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
36903698
def load_lora_weights(
@@ -3729,7 +3737,8 @@ def load_lora_weights(
37293737
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
37303738

37313739
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
3732-
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
3740+
kwargs["return_lora_metadata"] = True
3741+
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
37333742

37343743
is_correct_format = all("lora" in key for key in state_dict.keys())
37353744
if not is_correct_format:
@@ -3739,6 +3748,7 @@ def load_lora_weights(
37393748
state_dict,
37403749
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
37413750
adapter_name=adapter_name,
3751+
metadata=metadata,
37423752
_pipeline=self,
37433753
low_cpu_mem_usage=low_cpu_mem_usage,
37443754
hotswap=hotswap,
@@ -4064,7 +4074,8 @@ def load_lora_weights(
40644074
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
40654075

40664076
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
4067-
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
4077+
kwargs["return_lora_metadata"] = True
4078+
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
40684079

40694080
is_correct_format = all("lora" in key for key in state_dict.keys())
40704081
if not is_correct_format:
@@ -4074,6 +4085,7 @@ def load_lora_weights(
40744085
state_dict,
40754086
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
40764087
adapter_name=adapter_name,
4088+
metadata=metadata,
40774089
_pipeline=self,
40784090
low_cpu_mem_usage=low_cpu_mem_usage,
40794091
hotswap=hotswap,
@@ -4402,7 +4414,8 @@ def load_lora_weights(
44024414
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
44034415

44044416
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
4405-
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
4417+
kwargs["return_lora_metadata"] = True
4418+
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
44064419

44074420
is_correct_format = all("lora" in key for key in state_dict.keys())
44084421
if not is_correct_format:
@@ -4412,6 +4425,7 @@ def load_lora_weights(
44124425
state_dict,
44134426
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
44144427
adapter_name=adapter_name,
4428+
metadata=metadata,
44154429
_pipeline=self,
44164430
low_cpu_mem_usage=low_cpu_mem_usage,
44174431
hotswap=hotswap,
@@ -4741,7 +4755,8 @@ def load_lora_weights(
47414755
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
47424756

47434757
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
4744-
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
4758+
kwargs["return_lora_metadata"] = True
4759+
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
47454760

47464761
is_correct_format = all("lora" in key for key in state_dict.keys())
47474762
if not is_correct_format:
@@ -4751,6 +4766,7 @@ def load_lora_weights(
47514766
state_dict,
47524767
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
47534768
adapter_name=adapter_name,
4769+
metadata=metadata,
47544770
_pipeline=self,
47554771
low_cpu_mem_usage=low_cpu_mem_usage,
47564772
hotswap=hotswap,
@@ -5375,6 +5391,7 @@ def lora_state_dict(
53755391
allowed by Git.
53765392
subfolder (`str`, *optional*, defaults to `""`):
53775393
The subfolder location of a model file within a larger model repository on the Hub or locally.
5394+
return_lora_metadata: TODO
53785395
53795396
"""
53805397
# Load the main state dict first which has the LoRA layers for either of
@@ -5388,18 +5405,16 @@ def lora_state_dict(
53885405
subfolder = kwargs.pop("subfolder", None)
53895406
weight_name = kwargs.pop("weight_name", None)
53905407
use_safetensors = kwargs.pop("use_safetensors", None)
5408+
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
53915409

53925410
allow_pickle = False
53935411
if use_safetensors is None:
53945412
use_safetensors = True
53955413
allow_pickle = True
53965414

5397-
user_agent = {
5398-
"file_type": "attn_procs_weights",
5399-
"framework": "pytorch",
5400-
}
5415+
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
54015416

5402-
state_dict = _fetch_state_dict(
5417+
state_dict, metadata = _fetch_state_dict(
54035418
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
54045419
weight_name=weight_name,
54055420
use_safetensors=use_safetensors,
@@ -5420,7 +5435,8 @@ def lora_state_dict(
54205435
logger.warning(warn_msg)
54215436
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
54225437

5423-
return state_dict
5438+
out = (state_dict, metadata) if return_lora_metadata else state_dict
5439+
return out
54245440

54255441
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
54265442
def load_lora_weights(
@@ -5465,7 +5481,8 @@ def load_lora_weights(
54655481
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
54665482

54675483
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
5468-
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
5484+
kwargs["return_lora_metadata"] = True
5485+
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
54695486

54705487
is_correct_format = all("lora" in key for key in state_dict.keys())
54715488
if not is_correct_format:
@@ -5475,6 +5492,7 @@ def load_lora_weights(
54755492
state_dict,
54765493
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
54775494
adapter_name=adapter_name,
5495+
metadata=metadata,
54785496
_pipeline=self,
54795497
low_cpu_mem_usage=low_cpu_mem_usage,
54805498
hotswap=hotswap,
@@ -5803,7 +5821,8 @@ def load_lora_weights(
58035821
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
58045822

58055823
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
5806-
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
5824+
kwargs["return_lora_metadata"] = True
5825+
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
58075826

58085827
is_correct_format = all("lora" in key for key in state_dict.keys())
58095828
if not is_correct_format:
@@ -5813,6 +5832,7 @@ def load_lora_weights(
58135832
state_dict,
58145833
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
58155834
adapter_name=adapter_name,
5835+
metadata=metadata,
58165836
_pipeline=self,
58175837
low_cpu_mem_usage=low_cpu_mem_usage,
58185838
hotswap=hotswap,

0 commit comments

Comments
 (0)