
Commit f1be3eb

Merge branch 'chroma-fork' into chroma-final

2 parents: 6735507 + de9a07f

15 files changed (+901 −236 lines)

docs/source/en/_toctree.yml
Lines changed: 4 additions & 0 deletions

@@ -283,6 +283,8 @@
       title: AllegroTransformer3DModel
     - local: api/models/aura_flow_transformer2d
       title: AuraFlowTransformer2DModel
+    - local: api/models/chroma_transformer
+      title: ChromaTransformer2DModel
     - local: api/models/cogvideox_transformer3d
       title: CogVideoXTransformer3DModel
     - local: api/models/cogview3plus_transformer2d
@@ -405,6 +407,8 @@
       title: AutoPipeline
     - local: api/pipelines/blip_diffusion
       title: BLIP-Diffusion
+    - local: api/pipelines/chroma
+      title: Chroma
     - local: api/pipelines/cogvideox
       title: CogVideoX
     - local: api/pipelines/cogview3

examples/community/ip_adapter_face_id.py
Lines changed: 1 addition & 4 deletions

@@ -282,10 +282,7 @@ def load_ip_adapter_face_id(self, pretrained_model_name_or_path_or_dict, weight_
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)
 
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         model_file = _get_model_file(
             pretrained_model_name_or_path_or_dict,
             weights_name=weight_name,

src/diffusers/loaders/ip_adapter.py
Lines changed: 3 additions & 12 deletions

@@ -159,10 +159,7 @@ def load_ip_adapter(
                 " `low_cpu_mem_usage=False`."
             )
 
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         state_dicts = []
         for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
             pretrained_model_name_or_path_or_dict, weight_name, subfolder
@@ -465,10 +462,7 @@ def load_ip_adapter(
                 " `low_cpu_mem_usage=False`."
             )
 
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         state_dicts = []
         for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
             pretrained_model_name_or_path_or_dict, weight_name, subfolder
@@ -750,10 +744,7 @@ def load_ip_adapter(
                 " `low_cpu_mem_usage=False`."
             )
 
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
 
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             model_file = _get_model_file(

src/diffusers/loaders/lora_base.py
Lines changed: 46 additions & 6 deletions

@@ -14,6 +14,7 @@
 
 import copy
 import inspect
+import json
 import os
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
@@ -45,6 +46,7 @@
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.state_dict_utils import _load_sft_state_dict_metadata
 
 
 if is_transformers_available():
@@ -62,6 +64,7 @@
 
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"
 
 
 def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
@@ -206,6 +209,7 @@ def _fetch_state_dict(
     subfolder,
     user_agent,
     allow_pickle,
+    metadata=None,
 ):
     model_file = None
     if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -236,11 +240,14 @@ def _fetch_state_dict(
                     user_agent=user_agent,
                 )
                 state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                metadata = _load_sft_state_dict_metadata(model_file)
+
             except (IOError, safetensors.SafetensorError) as e:
                 if not allow_pickle:
                     raise e
                 # try loading non-safetensors weights
                 model_file = None
+                metadata = None
                 pass
 
         if model_file is None:
@@ -261,10 +268,11 @@ def _fetch_state_dict(
                 user_agent=user_agent,
             )
             state_dict = load_state_dict(model_file)
+            metadata = None
     else:
         state_dict = pretrained_model_name_or_path_or_dict
 
-    return state_dict
+    return state_dict, metadata
 
 
 def _best_guess_weight_name(
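`_load_sft_state_dict_metadata` is the new diffusers helper that pulls LoRA metadata out of a safetensors header; a minimal sketch of the idea, assuming the metadata sits under the `lora_adapter_metadata` key as a JSON string (an illustrative approximation, not the library's exact implementation):

import json
from typing import Optional

import safetensors

def load_lora_metadata(model_file: str) -> Optional[dict]:
    # safetensors exposes the free-form str->str header via .metadata()
    with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
        raw = (f.metadata() or {}).get("lora_adapter_metadata")
    # The value is JSON-encoded at save time, so decode it on the way out
    return json.loads(raw) if raw is not None else None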
@@ -306,6 +314,11 @@ def _best_guess_weight_name(
     return weight_name
 
 
+def _pack_dict_with_prefix(state_dict, prefix):
+    sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
+    return sd_with_prefix
+
+
 def _load_lora_into_text_encoder(
     state_dict,
     network_alphas,
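`_pack_dict_with_prefix` simply namespaces every key; a quick illustration with toy tensors and hypothetical key names:

import torch

sd = {"lora_A.weight": torch.zeros(4, 8), "lora_B.weight": torch.zeros(8, 4)}
packed = {f"text_encoder.{k}": v for k, v in sd.items()}  # what _pack_dict_with_prefix does
assert set(packed) == {"text_encoder.lora_A.weight", "text_encoder.lora_B.weight"}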
@@ -317,10 +330,14 @@ def _load_lora_into_text_encoder(
     _pipeline=None,
     low_cpu_mem_usage=False,
     hotswap: bool = False,
+    metadata=None,
 ):
     if not USE_PEFT_BACKEND:
         raise ValueError("PEFT backend is required for this method.")
 
+    if network_alphas and metadata:
+        raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.")
+
     peft_kwargs = {}
     if low_cpu_mem_usage:
         if not is_peft_version(">=", "0.13.1"):
@@ -349,6 +366,8 @@ def _load_lora_into_text_encoder(
     # Load the layers corresponding to text encoder and make necessary adjustments.
     if prefix is not None:
         state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+        if metadata is not None:
+            metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
 
     if len(state_dict) > 0:
         logger.info(f"Loading {prefix}.")
@@ -376,7 +395,10 @@ def _load_lora_into_text_encoder(
             alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
            network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}
 
-        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
+        if metadata is not None:
+            lora_config_kwargs = metadata
+        else:
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
 
         if "use_dora" in lora_config_kwargs:
             if lora_config_kwargs["use_dora"]:
@@ -398,7 +420,10 @@ def _load_lora_into_text_encoder(
             if is_peft_version("<=", "0.13.2"):
                 lora_config_kwargs.pop("lora_bias")
 
-        lora_config = LoraConfig(**lora_config_kwargs)
+        try:
+            lora_config = LoraConfig(**lora_config_kwargs)
+        except TypeError as e:
+            raise TypeError("`LoraConfig` class could not be instantiated.") from e
 
         # adapter_name
         if adapter_name is None:
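When metadata was embedded at save time, it is used verbatim as the `LoraConfig` kwargs, so a stale or hand-edited header surfaces as the `TypeError` raised above; a minimal sketch, assuming the metadata carries standard PEFT fields:

from peft import LoraConfig

# Hypothetical metadata as it might come back out of a safetensors header:
lora_config_kwargs = {"r": 16, "lora_alpha": 16, "target_modules": ["to_q", "to_k", "to_v"]}

try:
    lora_config = LoraConfig(**lora_config_kwargs)  # unknown keys raise TypeError
except TypeError as e:
    raise TypeError("`LoraConfig` class could not be instantiated.") from e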
@@ -889,8 +914,7 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
     @staticmethod
     def pack_weights(layers, prefix):
         layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
-        layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
-        return layers_state_dict
+        return _pack_dict_with_prefix(layers_weights, prefix)
 
     @staticmethod
     def write_lora_layers(
@@ -900,16 +924,32 @@ def write_lora_layers(
         weight_name: str,
         save_function: Callable,
         safe_serialization: bool,
+        lora_adapter_metadata: Optional[dict] = None,
     ):
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        if lora_adapter_metadata and not safe_serialization:
+            raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
+        if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict):
+            raise TypeError("`lora_adapter_metadata` must be of type `dict`.")
+
         if save_function is None:
             if safe_serialization:
 
                 def save_function(weights, filename):
-                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+                    # Inject framework format.
+                    metadata = {"format": "pt"}
+                    if lora_adapter_metadata:
+                        for key, value in lora_adapter_metadata.items():
+                            if isinstance(value, set):
+                                lora_adapter_metadata[key] = list(value)
+                        metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
+                            lora_adapter_metadata, indent=2, sort_keys=True
+                        )
+
+                    return safetensors.torch.save_file(weights, filename, metadata=metadata)
 
             else:
                 save_function = torch.save
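Taken together, `write_lora_layers` and `_fetch_state_dict` give adapter metadata a round trip through the safetensors header; a minimal self-contained sketch of that round trip (toy weights; the file name and config values are made up for illustration):

import json

import safetensors
import safetensors.torch
import torch

weights = {"unet.lora_A.weight": torch.zeros(4, 8)}
adapter_metadata = {"r": 4, "lora_alpha": 4, "target_modules": ["to_q"]}

# Save: JSON-encode the adapter config into the str->str header, next to
# the usual {"format": "pt"} marker.
safetensors.torch.save_file(
    weights,
    "toy_lora.safetensors",
    metadata={"format": "pt", "lora_adapter_metadata": json.dumps(adapter_metadata)},
)

# Load: tensors and header come back separately.
state_dict = safetensors.torch.load_file("toy_lora.safetensors")
with safetensors.safe_open("toy_lora.safetensors", framework="pt") as f:
    recovered = json.loads(f.metadata()["lora_adapter_metadata"])
assert recovered == adapter_metadata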

src/diffusers/loaders/lora_conversion_utils.py
Lines changed: 45 additions & 14 deletions

@@ -1605,9 +1605,18 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     if diff_keys:
         for diff_k in diff_keys:
             param = original_state_dict[diff_k]
+            # The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are up to 1e-3,
+            # and 2 of them are about 1.6e-2 [the case with the AccVideo LoRA]). The low magnitudes mostly correspond
+            # to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
+            # is okay to ignore because they do not affect the model output in a significant manner.
+            threshold = 1.6e-2
+            absdiff = param.abs().max() - param.abs().min()
             all_zero = torch.all(param == 0).item()
-            if all_zero:
-                logger.debug(f"Removed {diff_k} key from the state dict as it's all zeros.")
+            all_absdiff_lower_than_threshold = absdiff < threshold
+            if all_zero or all_absdiff_lower_than_threshold:
+                logger.debug(
+                    f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
+                )
                 original_state_dict.pop(diff_k)
 
     # For the `diff_b` keys, we treat them as lora_bias.
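The new pruning predicate keeps a `.diff` tensor only when its absolute values spread wider than the hardcoded 1.6e-2 threshold; a toy check of exactly that predicate (synthetic tensors, not real Wan weights):

import torch

threshold = 1.6e-2

norm_diff = torch.full((4,), 1e-4)     # tiny norm-layer delta -> dropped
real_diff = torch.tensor([0.0, 0.05])  # spread above threshold -> kept

for name, param in [("norm_diff", norm_diff), ("real_diff", real_diff)]:
    absdiff = param.abs().max() - param.abs().min()
    dropped = torch.all(param == 0).item() or (absdiff < threshold).item()
    print(name, "dropped" if dropped else "kept")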
@@ -1655,12 +1664,16 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
-            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.{lora_down_key}.weight"
-            )
-            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.{o}.{lora_up_key}.weight"
-            )
+            original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
+            converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
+            converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
             if f"blocks.{i}.{o}.diff_b" in original_state_dict:
                 converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop(
                     f"blocks.{i}.{o}.diff_b"
@@ -1669,12 +1682,16 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     # Remaining.
     if original_state_dict:
         if any("time_projection" in k for k in original_state_dict):
-            converted_state_dict["condition_embedder.time_proj.lora_A.weight"] = original_state_dict.pop(
-                f"time_projection.1.{lora_down_key}.weight"
-            )
-            converted_state_dict["condition_embedder.time_proj.lora_B.weight"] = original_state_dict.pop(
-                f"time_projection.1.{lora_up_key}.weight"
-            )
+            original_key = f"time_projection.1.{lora_down_key}.weight"
+            converted_key = "condition_embedder.time_proj.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"time_projection.1.{lora_up_key}.weight"
+            converted_key = "condition_embedder.time_proj.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
         if "time_projection.1.diff_b" in original_state_dict:
             converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
                 "time_projection.1.diff_b"
@@ -1709,6 +1726,20 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
                     original_state_dict.pop(f"{text_time}.{b_n}.diff_b")
                 )
 
+        for img_ours, img_theirs in [
+            ("ff.net.0.proj", "img_emb.proj.1"),
+            ("ff.net.2", "img_emb.proj.3"),
+        ]:
+            original_key = f"{img_theirs}.{lora_down_key}.weight"
+            converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"{img_theirs}.{lora_up_key}.weight"
+            converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
     if len(original_state_dict) > 0:
         diff = all(".diff" in k for k in original_state_dict)
         if diff:
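Every conversion in this file now follows the same guarded pop-and-rename pattern, so partial checkpoints (e.g. a Wan LoRA trained without the image embedder) convert without a KeyError from a bare pop(); a small distillation of the pattern, with a hypothetical helper name:

def _maybe_rename(src: dict, dst: dict, original_key: str, converted_key: str) -> None:
    # Move the tensor only if the checkpoint actually contains it.
    if original_key in src:
        dst[converted_key] = src.pop(original_key)

converted, original = {}, {"time_projection.1.lora_down.weight": "w"}
_maybe_rename(original, converted, "time_projection.1.lora_down.weight",
              "condition_embedder.time_proj.lora_A.weight")
_maybe_rename(original, converted, "time_projection.1.lora_up.weight",
              "condition_embedder.time_proj.lora_B.weight")  # absent -> skipped
assert "condition_embedder.time_proj.lora_A.weight" in converted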
