Commit 43ff55f

Merge branch 'main' into reuse-attn-mixin
2 parents 45ef48b + 152f7ca

26 files changed: +844 −153 lines

docs/source/en/modular_diffusers/guiders.md

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ Change the [`~ComponentSpec.default_creation_method`] to `from_pretrained` and u
 ```py
 guider_spec = t2i_pipeline.get_component_spec("guider")
 guider_spec.default_creation_method="from_pretrained"
-guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
+guider_spec.pretrained_model_name_or_path="YiYiXu/modular-loader-t2i-guider"
 guider_spec.subfolder="pag_guider"
 pag_guider = guider_spec.load()
 t2i_pipeline.update_components(guider=pag_guider)

docs/source/en/modular_diffusers/modular_pipeline.md

Lines changed: 2 additions & 2 deletions
@@ -313,14 +313,14 @@ unet_spec
 ComponentSpec(
     name='unet',
     type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
-    repo='RunDiffusion/Juggernaut-XL-v9',
+    pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
     subfolder='unet',
     variant='fp16',
     default_creation_method='from_pretrained'
 )

 # modify to load from a different repository
-unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
+unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"

 # load component with modified spec
 unet = unet_spec.load(torch_dtype=torch.float16)
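
Note: the renamed loading field works the same way when a spec is built by hand rather than pulled from a pipeline. A minimal sketch, assuming the field names shown in the spec repr above (the repo id and dtype are illustrative):

```py
import torch
from diffusers import ComponentSpec, UNet2DConditionModel

# Build a spec directly with the renamed loading field; repo id and
# subfolder mirror the example in the changed docs.
unet_spec = ComponentSpec(
    name="unet",
    type_hint=UNet2DConditionModel,
    pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
    default_creation_method="from_pretrained",
)
unet = unet_spec.load(torch_dtype=torch.float16)
```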

docs/source/zh/modular_diffusers/guiders.md

Lines changed: 1 addition & 1 deletion
@@ -157,7 +157,7 @@ guider.push_to_hub("YiYiXu/modular-loader-t2i-guider", subfolder="pag_guider")
 ```py
 guider_spec = t2i_pipeline.get_component_spec("guider")
 guider_spec.default_creation_method="from_pretrained"
-guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
+guider_spec.pretrained_model_name_or_path="YiYiXu/modular-loader-t2i-guider"
 guider_spec.subfolder="pag_guider"
 pag_guider = guider_spec.load()
 t2i_pipeline.update_components(guider=pag_guider)

docs/source/zh/modular_diffusers/modular_pipeline.md

Lines changed: 2 additions & 2 deletions
@@ -313,14 +313,14 @@ unet_spec
 ComponentSpec(
     name='unet',
     type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
-    repo='RunDiffusion/Juggernaut-XL-v9',
+    pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
     subfolder='unet',
     variant='fp16',
     default_creation_method='from_pretrained'
 )

 # modify to load from a different repository
-unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
+unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"

 # load the component with the modified spec
 unet = unet_spec.load(torch_dtype=torch.float16)

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 61 additions & 13 deletions
@@ -37,7 +37,7 @@
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from peft import LoraConfig
-from peft.utils import get_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -46,7 +46,12 @@
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import cast_training_params, compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
@@ -708,6 +713,56 @@ def collate_fn(examples):
     num_workers=args.dataloader_num_workers,
 )

+def save_model_hook(models, weights, output_dir):
+    if accelerator.is_main_process:
+        unet_lora_layers_to_save = None
+
+        for model in models:
+            if isinstance(model, type(unwrap_model(unet))):
+                unet_lora_layers_to_save = get_peft_model_state_dict(model)
+            else:
+                raise ValueError(f"Unexpected save model: {model.__class__}")
+
+            # make sure to pop weight so that corresponding model is not saved again
+            weights.pop()
+
+        StableDiffusionPipeline.save_lora_weights(
+            save_directory=output_dir,
+            unet_lora_layers=unet_lora_layers_to_save,
+            safe_serialization=True,
+        )
+
+def load_model_hook(models, input_dir):
+    unet_ = None
+
+    while len(models) > 0:
+        model = models.pop()
+        if isinstance(model, type(unwrap_model(unet))):
+            unet_ = model
+        else:
+            raise ValueError(f"unexpected save model: {model.__class__}")
+
+    # returns a tuple of state dictionary and network alphas
+    lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
+
+    unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+    unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+    incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+
+    if incompatible_keys is not None:
+        # check only for unexpected keys
+        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+        # throw warning if some unexpected keys are found and continue loading
+        if unexpected_keys:
+            logger.warning(
+                f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                f" {unexpected_keys}. "
+            )
+
+    # Make sure the trainable params are in float32
+    if args.mixed_precision in ["fp16"]:
+        cast_training_params([unet_], dtype=torch.float32)
+
 # Scheduler and math around the number of training steps.
 # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
 num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
@@ -732,6 +787,10 @@ def collate_fn(examples):
     unet, optimizer, train_dataloader, lr_scheduler
 )

+# Register the hooks for efficient saving and loading of LoRA weights
+accelerator.register_save_state_pre_hook(save_model_hook)
+accelerator.register_load_state_pre_hook(load_model_hook)
+
 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
 if args.max_train_steps is None:
@@ -906,17 +965,6 @@ def collate_fn(examples):
 save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
 accelerator.save_state(save_path)

-unwrapped_unet = unwrap_model(unet)
-unet_lora_state_dict = convert_state_dict_to_diffusers(
-    get_peft_model_state_dict(unwrapped_unet)
-)
-
-StableDiffusionPipeline.save_lora_weights(
-    save_directory=save_path,
-    unet_lora_layers=unet_lora_state_dict,
-    safe_serialization=True,
-)
-
 logger.info(f"Saved state to {save_path}")

 logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
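
Note: the change moves LoRA serialization out of the manual checkpointing branch and into Accelerate's hook mechanism, so `save_state`/`load_state` persist and restore only the adapter weights. A minimal sketch of the hook contract as Accelerate defines it, with placeholder bodies standing in for the LoRA logic above:

```py
from accelerate import Accelerator

accelerator = Accelerator()

def save_hook(models, weights, output_dir):
    # invoked by accelerator.save_state() before model weights are written;
    # popping an entry from `weights` tells Accelerate not to serialize that
    # model itself (the hook saves it in its own format instead)
    if weights:
        weights.pop()

def load_hook(models, input_dir):
    # invoked by accelerator.load_state(); models popped from the list are
    # loaded by the hook instead of Accelerate's default loader
    while models:
        models.pop()

accelerator.register_save_state_pre_hook(save_hook)
accelerator.register_load_state_pre_hook(load_hook)

accelerator.save_state("checkpoint-500")  # routes through save_hook
accelerator.load_state("checkpoint-500")  # routes through load_hook
```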

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -991,6 +991,7 @@
     WanAnimateTransformer3DModel,
     WanTransformer3DModel,
     WanVACETransformer3DModel,
+    ZImageTransformer2DModel,
     attention_backend,
 )
 from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks
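
With this export in place, the new transformer is importable from the package root (assuming a diffusers build that includes it):

```py
from diffusers import ZImageTransformer2DModel

print(ZImageTransformer2DModel.__name__)  # ZImageTransformer2DModel
```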

src/diffusers/loaders/single_file_utils.py

Lines changed: 8 additions & 1 deletion
@@ -389,6 +389,14 @@ def is_valid_url(url):
     return False


+def _is_single_file_path_or_url(pretrained_model_name_or_path):
+    if not os.path.isfile(pretrained_model_name_or_path) and not is_valid_url(pretrained_model_name_or_path):
+        return False
+
+    repo_id, weight_name = _extract_repo_id_and_weights_name(pretrained_model_name_or_path)
+    return bool(repo_id and weight_name)
+
+
 def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
     if not is_valid_url(pretrained_model_name_or_path):
         raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.")
@@ -400,7 +408,6 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
     pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
     match = re.match(pattern, pretrained_model_name_or_path)
     if not match:
-        logger.warning("Unable to identify the repo_id and weights_name from the provided URL.")
         return repo_id, weights_name

     repo_id = f"{match.group(1)}/{match.group(2)}"
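
For context, the new predicate leans on `_extract_repo_id_and_weights_name`, which splits a Hub file URL into a repo id and a weight filename. A minimal sketch with an illustrative URL:

```py
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name

# Illustrative URL; any https://huggingface.co/<org>/<repo>/blob/main/<file> works
url = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id, weights_name = _extract_repo_id_and_weights_name(url)
print(repo_id)       # stabilityai/stable-diffusion-xl-base-1.0
print(weights_name)  # sd_xl_base_1.0.safetensors
```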

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 17 additions & 3 deletions
@@ -69,7 +69,10 @@ def timestep_embedding(t, dim, max_period=10000):

 def forward(self, t):
     t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
-    t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
+    weight_dtype = self.mlp[0].weight.dtype
+    if weight_dtype.is_floating_point:
+        t_freq = t_freq.to(weight_dtype)
+    t_emb = self.mlp(t_freq)
     return t_emb

@@ -126,6 +129,10 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
 dtype = query.dtype
 query, key = query.to(dtype), key.to(dtype)

+# From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
+if attention_mask is not None and attention_mask.ndim == 2:
+    attention_mask = attention_mask[:, None, None, :]
+
 # Compute joint attention
 hidden_states = dispatch_attention_fn(
     query,
@@ -306,6 +313,10 @@ def __call__(self, ids: torch.Tensor):
 if self.freqs_cis is None:
     self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
     self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
+else:
+    # Ensure freqs_cis are on the same device as ids
+    if self.freqs_cis[0].device != device:
+        self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]

 result = []
 for i in range(len(self.axes_dims)):
@@ -317,6 +328,8 @@
 class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     _supports_gradient_checkpointing = True
     _no_split_modules = ["ZImageTransformerBlock"]
+    _repeated_blocks = ["ZImageTransformerBlock"]
+    _skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"]  # precision sensitive layers

     @register_to_config
     def __init__(
@@ -553,8 +566,6 @@ def forward(
 t = t * self.t_scale
 t = self.t_embedder(t)

-adaln_input = t
-
 (
     x,
     cap_feats,
@@ -572,6 +583,9 @@

 x = torch.cat(x, dim=0)
 x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
+
+# Match t_embedder output dtype to x for layerwise casting compatibility
+adaln_input = t.type_as(x)
 x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
 x = list(x.split(x_item_seqlens, dim=0))
 x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
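
To make the mask-shape handling concrete, the same broadcast can be reproduced with plain tensors; a self-contained sketch (shapes are illustrative):

```py
import torch

batch, heads, seq_len = 2, 8, 16
attention_mask = torch.ones(batch, seq_len, dtype=torch.bool)  # [batch, seq_len]
attention_mask[0, 10:] = False  # mask out the padded tail of the first sample

# [batch, seq_len] -> [batch, 1, 1, seq_len]; broadcasts against
# attention scores shaped [batch, heads, query_len, key_len]
mask_4d = attention_mask[:, None, None, :]

scores = torch.randn(batch, heads, seq_len, seq_len)
scores = scores.masked_fill(~mask_4d, float("-inf"))
print(scores.shape)  # torch.Size([2, 8, 16, 16])
```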

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 33 additions & 20 deletions
@@ -360,7 +360,7 @@ def init_pipeline(
     collection: Optional[str] = None,
 ) -> "ModularPipeline":
     """
-    create a ModularPipeline, optionally accept modular_repo to load from hub.
+    create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub.
     """
     pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__)
     diffusers_module = importlib.import_module("diffusers")
@@ -1645,8 +1645,8 @@ def from_pretrained(
 pretrained_model_name_or_path (`str` or `os.PathLike`, optional):
     Path to a pretrained pipeline configuration. It will first try to load config from
     `modular_model_index.json`, then fallback to `model_index.json` for compatibility with standard
-    non-modular repositories. If the repo does not contain any pipeline config, it will be set to None
-    during initialization.
+    non-modular repositories. If the pretrained_model_name_or_path does not contain any pipeline config, it
+    will be set to None during initialization.
 trust_remote_code (`bool`, optional):
     Whether to trust remote code when loading the pipeline, need to be set to True if you want to create
     pipeline blocks based on the custom code in `pretrained_model_name_or_path`
@@ -1807,7 +1807,7 @@ def register_components(self, **kwargs):
 library, class_name = None, None

 # extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config
-# e.g. {"repo": "stabilityai/stable-diffusion-2-1",
+# e.g. {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1",
 #       "type_hint": ("diffusers", "UNet2DConditionModel"),
 #       "subfolder": "unet",
 #       "variant": None,
@@ -2111,8 +2111,10 @@ def load_components(self, names: Optional[Union[List[str], str]] = None, **kwargs):
 **kwargs: additional kwargs to be passed to `from_pretrained()`. Can be:
     - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
     - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
-    - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`,
-      `variant`, `revision`, etc.
+    - can potentially override ComponentSpec fields if passed a different loading field in kwargs, e.g.
+      `pretrained_model_name_or_path`, `variant`, `revision`, etc.
 """

 if names is None:
@@ -2378,10 +2380,10 @@ def _component_spec_to_dict(component_spec: ComponentSpec) -> Any:
 - "type_hint": Tuple[str, str]
     Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
 - All loading fields defined by `component_spec.loading_fields()`, typically:
-    - "repo": Optional[str]
-        The model repository (e.g., "stabilityai/stable-diffusion-xl").
+    - "pretrained_model_name_or_path": Optional[str]
+        The model repository (e.g., "stabilityai/stable-diffusion-xl").
 - "subfolder": Optional[str]
-    A subfolder within the repo where this component lives.
+    A subfolder within the repository where this component lives.
 - "variant": Optional[str]
     An optional variant identifier for the model.
 - "revision": Optional[str]
@@ -2398,11 +2400,13 @@ def _component_spec_to_dict(component_spec: ComponentSpec) -> Any:
 Example:
     >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec
     >>> from diffusers import UNet2DConditionModel
     >>> spec = ComponentSpec(
-    ...     name="unet", type_hint=UNet2DConditionModel, config=None, repo="path/to/repo",
-    ...     subfolder="subfolder", variant=None, revision=None,
-    ...     default_creation_method="from_pretrained",
+    ...     name="unet", type_hint=UNet2DConditionModel, config=None,
+    ...     pretrained_model_name_or_path="path/to/repo", subfolder="subfolder",
+    ...     variant=None, revision=None, default_creation_method="from_pretrained",
     ... )
     >>> ModularPipeline._component_spec_to_dict(spec)
-    {"type_hint": ("diffusers", "UNet2DConditionModel"), "repo": "path/to/repo", "subfolder": "subfolder",
-     "variant": None, "revision": None}
+    {"type_hint": ("diffusers", "UNet2DConditionModel"), "pretrained_model_name_or_path": "path/to/repo",
+     "subfolder": "subfolder", "variant": None, "revision": None}
 """
@@ -2432,10 +2436,10 @@ def _dict_to_component_spec(
 - "type_hint": Tuple[str, str]
     Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
 - All loading fields defined by `component_spec.loading_fields()`, typically:
-    - "repo": Optional[str]
+    - "pretrained_model_name_or_path": Optional[str]
         The model repository (e.g., "stabilityai/stable-diffusion-xl").
 - "subfolder": Optional[str]
-    A subfolder within the repo where this component lives.
+    A subfolder within the pretrained_model_name_or_path where this component lives.
 - "variant": Optional[str]
     An optional variant identifier for the model.
 - "revision": Optional[str]
@@ -2452,11 +2456,20 @@ def _dict_to_component_spec(
     ComponentSpec: A reconstructed ComponentSpec object.

 Example:
-    >>> spec_dict = {
-    ...     "type_hint": ("diffusers", "UNet2DConditionModel"),
-    ...     "repo": "stabilityai/stable-diffusion-xl",
-    ...     "subfolder": "unet", "variant": None, "revision": None,
-    ... }
-    >>> ModularPipeline._dict_to_component_spec("unet", spec_dict)
-    ComponentSpec(
-        name="unet", type_hint=UNet2DConditionModel, config=None, repo="stabilityai/stable-diffusion-xl",
-        subfolder="unet", variant=None, revision=None, default_creation_method="from_pretrained"
+    >>> spec_dict = {
+    ...     "type_hint": ("diffusers", "UNet2DConditionModel"),
+    ...     "pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl",
+    ...     "subfolder": "unet", "variant": None, "revision": None,
+    ... }
+    >>> ModularPipeline._dict_to_component_spec("unet", spec_dict)
+    ComponentSpec(
+        name="unet", type_hint=UNet2DConditionModel, config=None,
+        pretrained_model_name_or_path="stabilityai/stable-diffusion-xl",
+        subfolder="unet", variant=None, revision=None, default_creation_method="from_pretrained"
     )
 """
 # make a shallow copy so we can pop() safely
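
These two helpers round-trip a spec through its `modular_model_index.json` representation. A sketch of the round trip under the renamed field (the repo id is illustrative, and both methods are private API):

```py
from diffusers import ComponentSpec, ModularPipeline, UNet2DConditionModel

spec = ComponentSpec(
    name="unet",
    type_hint=UNet2DConditionModel,
    pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
    default_creation_method="from_pretrained",
)

# serialize to the modular_model_index.json representation and back
spec_dict = ModularPipeline._component_spec_to_dict(spec)
print(spec_dict["pretrained_model_name_or_path"])  # stabilityai/stable-diffusion-xl-base-1.0

restored = ModularPipeline._dict_to_component_spec("unet", spec_dict)
print(restored.pretrained_model_name_or_path)  # same value, round-tripped
```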
