Commit cdc14c5
Merge branch 'main' into auraflow-gguf
2 parents c6a4488 + b572635 commit cdc14c5

File tree

17 files changed (+390, -79 lines)

docs/source/en/tutorials/using_peft_for_inference.md

Lines changed: 6 additions & 6 deletions

@@ -56,7 +56,7 @@ image

With the `adapter_name` parameter, it is really easy to use another adapter for inference! Load the [nerijs/pixel-art-xl](https://huggingface.co/nerijs/pixel-art-xl) adapter that has been fine-tuned to generate pixel art images and call it `"pixel"`.

-The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~PeftAdapterMixin.set_adapters`] method:
+The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~loaders.peft.PeftAdapterMixin.set_adapters`] method:

```python
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
@@ -85,7 +85,7 @@ By default, if the most up-to-date versions of PEFT and Transformers are detecte

You can also merge different adapter checkpoints for inference to blend their styles together.

-Once again, use the [`~PeftAdapterMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.
+Once again, use the [`~loaders.peft.PeftAdapterMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.

```python
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
@@ -114,7 +114,7 @@ Impressive! As you can see, the model generated an image that mixed the characte
> [!TIP]
> Through its PEFT integration, Diffusers also offers more efficient merging methods which you can learn about in the [Merge LoRAs](../using-diffusers/merge_loras) guide!

-To return to only using one adapter, use the [`~PeftAdapterMixin.set_adapters`] method to activate the `"toy"` adapter:
+To return to only using one adapter, use the [`~loaders.peft.PeftAdapterMixin.set_adapters`] method to activate the `"toy"` adapter:

```python
pipe.set_adapters("toy")
@@ -127,7 +127,7 @@ image = pipe(
image
```

-Or to disable all adapters entirely, use the [`~PeftAdapterMixin.disable_lora`] method to return the base model.
+Or to disable all adapters entirely, use the [`~loaders.peft.PeftAdapterMixin.disable_lora`] method to return the base model.

```python
pipe.disable_lora()
@@ -141,7 +141,7 @@ image

### Customize adapters strength

-For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~PeftAdapterMixin.set_adapters`].
+For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~loaders.peft.PeftAdapterMixin.set_adapters`].

For example, here's how you can turn on the adapter for the `down` parts, but turn it off for the `mid` and `up` parts:
```python
@@ -214,7 +214,7 @@ list_adapters_component_wise
{"text_encoder": ["toy", "pixel"], "unet": ["toy", "pixel"], "text_encoder_2": ["toy", "pixel"]}
```

-The [`~PeftAdapterMixin.delete_adapters`] function completely removes an adapter and their LoRA layers from a model.
+The [`~loaders.peft.PeftAdapterMixin.delete_adapters`] function completely removes an adapter and their LoRA layers from a model.

```py
pipe.delete_adapters("toy")
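
For orientation, a condensed sketch of the adapter workflow these renamed cross-references describe. The base checkpoint and the `"toy"` adapter repository are recalled from the surrounding guide and may need checking; `load_lora_weights`, `set_adapters`, `disable_lora`, and `delete_adapters` are the calls shown in the diff itself, and the scales-dict shape is an assumption:

```python
import torch
from diffusers import AutoPipelineForText2Image

# Base checkpoint and "toy" adapter repo assumed from the guide; the "pixel" adapter
# and the method calls below appear verbatim in the diff above.
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])  # blend both adapters
pipe.set_adapters("pixel", {"unet": {"down": 1.0, "mid": 0.0, "up": 0.0}})  # per-block scales (assumed shape)
pipe.set_adapters("toy")  # back to a single adapter
pipe.disable_lora()  # base model only
pipe.delete_adapters("toy")  # remove the adapter and its LoRA layers
```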

examples/community/rerender_a_video.py

Lines changed: 14 additions & 4 deletions

@@ -30,10 +30,17 @@
from diffusers.pipelines.controlnet.pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import BaseOutput, deprecate, logging
+from diffusers.utils import BaseOutput, deprecate, is_torch_xla_available, logging
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -775,7 +782,7 @@ def __call__(
        self.attn_state.reset()

        # 4.1 prepare frames
-        image = self.image_processor.preprocess(frames[0]).to(dtype=torch.float32)
+        image = self.image_processor.preprocess(frames[0]).to(dtype=self.dtype)
        first_image = image[0]  # C, H, W

        # 4.2 Prepare controlnet_conditioning_image
@@ -919,8 +926,8 @@ def __call__(
            prev_image = frames[idx - 1]
            control_image = control_frames[idx]
            # 5.1 prepare frames
-            image = self.image_processor.preprocess(image).to(dtype=torch.float32)
-            prev_image = self.image_processor.preprocess(prev_image).to(dtype=torch.float32)
+            image = self.image_processor.preprocess(image).to(dtype=self.dtype)
+            prev_image = self.image_processor.preprocess(prev_image).to(dtype=self.dtype)

            warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
                self.flow_model, first_image, image[0], first_result, False, self.device
@@ -1100,6 +1107,9 @@ def denoising_loop(latents, mask=None, xtrg=None, noise_rescale=None):
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
            return latents

        if mask_start_t <= mask_end_t:
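
The new import guard and `mark_step()` call follow the XLA pattern used elsewhere in diffusers. A self-contained sketch of the same idea, with a placeholder loop body standing in for the real denoising step:

```python
import torch
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def denoise(latents: torch.Tensor, num_steps: int = 4) -> torch.Tensor:
    for _ in range(num_steps):
        latents = latents * 0.99  # stand-in for one scheduler/UNet step
        if XLA_AVAILABLE:
            # On TPU/XLA, mark_step() cuts the lazily built graph so each
            # iteration is compiled and executed instead of growing one
            # unbounded graph across the whole loop.
            xm.mark_step()
    return latents
```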

examples/flux-control/train_control_lora_flux.py

Lines changed: 18 additions & 1 deletion

@@ -923,11 +923,28 @@ def load_model_hook(models, input_dir):
                transformer_ = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")
-
        else:
            transformer_ = FluxTransformer2DModel.from_pretrained(
                args.pretrained_model_name_or_path, subfolder="transformer"
            ).to(accelerator.device, weight_dtype)
+
+            # Handle input dimension doubling before adding adapter
+            with torch.no_grad():
+                initial_input_channels = transformer_.config.in_channels
+                new_linear = torch.nn.Linear(
+                    transformer_.x_embedder.in_features * 2,
+                    transformer_.x_embedder.out_features,
+                    bias=transformer_.x_embedder.bias is not None,
+                    dtype=transformer_.dtype,
+                    device=transformer_.device,
+                )
+                new_linear.weight.zero_()
+                new_linear.weight[:, :initial_input_channels].copy_(transformer_.x_embedder.weight)
+                if transformer_.x_embedder.bias is not None:
+                    new_linear.bias.copy_(transformer_.x_embedder.bias)
+                transformer_.x_embedder = new_linear
+                transformer_.register_to_config(in_channels=initial_input_channels * 2)
+
        transformer_.add_adapter(transformer_lora_config)

        lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
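
The new hook logic widens `x_embedder` to twice the input channels and zero-fills the new weight columns, so the expanded layer reproduces the original outputs on the first half of its input. A minimal PyTorch sketch of that invariant, independent of Flux:

```python
import torch

old = torch.nn.Linear(64, 128)

# Double the input features; the extra columns start at zero, so existing behavior is preserved.
new = torch.nn.Linear(old.in_features * 2, old.out_features, bias=old.bias is not None)
with torch.no_grad():
    new.weight.zero_()
    new.weight[:, : old.in_features].copy_(old.weight)
    if old.bias is not None:
        new.bias.copy_(old.bias)

x = torch.randn(2, 64)
extra = torch.randn(2, 64)  # e.g. control latents concatenated along the channel dimension
assert torch.allclose(old(x), new(torch.cat([x, extra], dim=-1)), atol=1e-5)
```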

examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 1 addition & 1 deletion

@@ -919,7 +919,7 @@ def preprocess_train(examples):
        # fingerprint used by the cache for the other processes to load the result
        # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
        new_fingerprint = Hasher.hash(args)
-        new_fingerprint_for_vae = Hasher.hash(vae_path)
+        new_fingerprint_for_vae = Hasher.hash((vae_path, args))
        train_dataset_with_embeddings = train_dataset.map(
            compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
        )
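
Hashing the `(vae_path, args)` tuple makes the cached VAE-embedding fingerprint depend on the training arguments as well, so changing preprocessing options invalidates the cache even when the VAE path stays the same. An illustrative sketch of that idea using `hashlib` rather than the `Hasher` from `datasets` (the path and argument values below are made up):

```python
import hashlib
import pickle


def fingerprint(*parts) -> str:
    """Derive a cache key that changes whenever any of its parts change."""
    return hashlib.sha256(pickle.dumps(parts)).hexdigest()[:16]


vae_path = "madebyollin/sdxl-vae-fp16-fix"  # illustrative value
args_a = {"resolution": 1024, "center_crop": False}
args_b = {"resolution": 512, "center_crop": False}

# Same VAE but different preprocessing args now yields a different cache key,
# so stale cached embeddings are not reused.
assert fingerprint(vae_path, args_a) != fingerprint(vae_path, args_b)
```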

src/diffusers/loaders/lora_pipeline.py

Lines changed: 3 additions & 1 deletion

@@ -2466,7 +2466,9 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
                continue

            base_param_name = (
-                f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
+                f"{k.replace(prefix, '')}.base_layer.weight"
+                if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
+                else f"{k.replace(prefix, '')}.weight"
            )
            base_weight_param = transformer_state_dict[base_param_name]
            lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
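
The added condition prefers the PEFT-wrapped `.base_layer.weight` key only when it actually exists in the transformer state dict and otherwise falls back to the plain `.weight` key. A small sketch of that lookup with toy state dicts:

```python
import torch

# Toy stand-ins: a module that already has a PEFT-wrapped layer and one that does not.
peft_state_dict = {"proj.base_layer.weight": torch.zeros(4, 4)}
plain_state_dict = {"proj.weight": torch.zeros(4, 4)}


def resolve_base_param_name(key: str, state_dict: dict, is_peft_loaded: bool) -> str:
    candidate = f"{key}.base_layer.weight"
    # Use the PEFT-wrapped name only if it is really present; even in a PEFT-loaded
    # model, some layers may not be wrapped (e.g. when the LoRA targets a subset).
    if is_peft_loaded and candidate in state_dict:
        return candidate
    return f"{key}.weight"


assert resolve_base_param_name("proj", peft_state_dict, is_peft_loaded=True) == "proj.base_layer.weight"
assert resolve_base_param_name("proj", plain_state_dict, is_peft_loaded=True) == "proj.weight"
```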

src/diffusers/models/embeddings.py

Lines changed: 2 additions & 1 deletion

@@ -1248,7 +1248,8 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
        sin_out = []
        pos = ids.float()
        is_mps = ids.device.type == "mps"
-        freqs_dtype = torch.float32 if is_mps else torch.float64
+        is_npu = ids.device.type == "npu"
+        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
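
The change extends the float32 fallback for rotary-embedding frequencies to NPU devices, which (like MPS) lack float64 support. A standalone sketch of the dtype selection:

```python
import torch


def rope_freqs_dtype(device_type: str) -> torch.dtype:
    # MPS and NPU backends do not support float64, so fall back to float32 there;
    # everywhere else keep float64 for more precise rotary frequencies.
    return torch.float32 if device_type in ("mps", "npu") else torch.float64


assert rope_freqs_dtype("cuda") is torch.float64
assert rope_freqs_dtype("npu") is torch.float32
```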

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 1 addition & 1 deletion

@@ -250,14 +250,14 @@ def __init__(
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Patch Embedding
-        interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
        self.patch_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            interpolation_scale=interpolation_scale,
+            pos_embed_type="sincos" if interpolation_scale is not None else None,
        )

        # 2. Additional condition embeddings
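
With the removed default, passing `interpolation_scale=None` now also disables the sinusoidal position embedding instead of silently substituting `max(sample_size // 64, 1)`. A toy sketch of the conditional argument (the helper name is illustrative, not part of diffusers):

```python
from typing import Optional


def patch_embed_kwargs(interpolation_scale: Optional[float]) -> dict:
    # Only request a sincos positional embedding when an interpolation scale
    # was explicitly provided; otherwise skip positional embeddings entirely.
    return {
        "interpolation_scale": interpolation_scale,
        "pos_embed_type": "sincos" if interpolation_scale is not None else None,
    }


assert patch_embed_kwargs(None)["pos_embed_type"] is None
assert patch_embed_kwargs(1.0)["pos_embed_type"] == "sincos"
```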

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 63 additions & 42 deletions

@@ -85,11 +85,11 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):

    def forward(
        self,
-        hidden_states: torch.FloatTensor,
-        temb: torch.FloatTensor,
-        image_rotary_emb=None,
-        joint_attention_kwargs=None,
-    ):
+        hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> torch.Tensor:
        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -117,15 +117,22 @@ class FluxTransformerBlock(nn.Module):

    Reference: https://arxiv.org/abs/2403.03206

-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
-            processing of `context` conditions.
+    Args:
+        dim (`int`):
+            The embedding dimension of the block.
+        num_attention_heads (`int`):
+            The number of attention heads to use.
+        attention_head_dim (`int`):
+            The number of dimensions to use for each attention head.
+        qk_norm (`str`, defaults to `"rms_norm"`):
+            The normalization to use for the query and key tensors.
+        eps (`float`, defaults to `1e-6`):
+            The epsilon value to use for the normalization.
    """

-    def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
+    def __init__(
+        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+    ):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim)
@@ -164,12 +171,12 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_no

    def forward(
        self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor,
-        temb: torch.FloatTensor,
-        image_rotary_emb=None,
-        joint_attention_kwargs=None,
-    ):
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
@@ -227,16 +234,30 @@ class FluxTransformer2DModel(

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

-    Parameters:
-        patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
-        num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
-        num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
-        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
-        guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+    Args:
+        patch_size (`int`, defaults to `1`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `64`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `None`):
+            The number of channels in the output. If not specified, it defaults to `in_channels`.
+        num_layers (`int`, defaults to `19`):
+            The number of layers of dual stream DiT blocks to use.
+        num_single_layers (`int`, defaults to `38`):
+            The number of layers of single stream DiT blocks to use.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of dimensions to use for each attention head.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of attention heads to use.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The number of dimensions to use for the joint attention (embedding/channel dimension of
+            `encoder_hidden_states`).
+        pooled_projection_dim (`int`, defaults to `768`):
+            The number of dimensions to use for the pooled projection.
+        guidance_embeds (`bool`, defaults to `False`):
+            Whether to use guidance embeddings for guidance-distilled variant of the model.
+        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+            The dimensions to use for the rotary positional embeddings.
    """

    _supports_gradient_checkpointing = True
@@ -259,39 +280,39 @@ def __init__(
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
-            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )

-        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
-        self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)
+        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+        self.x_embedder = nn.Linear(in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                )
-                for i in range(self.config.num_layers)
+                for _ in range(num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                )
-                for i in range(self.config.num_single_layers)
+                for _ in range(num_single_layers)
            ]
        )

@@ -418,16 +439,16 @@ def forward(
        controlnet_single_block_samples=None,
        return_dict: bool = True,
        controlnet_blocks_repeat: bool = False,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        The [`FluxTransformer2DModel`] forward method.

        Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states`.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep ( `torch.LongTensor`):
                Used to indicate denoising step.
src/diffusers/pipelines/auto_pipeline.py

Lines changed: 9 additions & 3 deletions

@@ -528,7 +528,9 @@ def from_pipe(cls, pipeline, **kwargs):
            if k not in text_2_image_kwargs
        }

-        missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(text_2_image_kwargs.keys())
+        missing_modules = (
+            set(expected_modules) - set(text_2_image_cls._optional_components) - set(text_2_image_kwargs.keys())
+        )

        if len(missing_modules) > 0:
            raise ValueError(
@@ -838,7 +840,9 @@ def from_pipe(cls, pipeline, **kwargs):
            if k not in image_2_image_kwargs
        }

-        missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(image_2_image_kwargs.keys())
+        missing_modules = (
+            set(expected_modules) - set(image_2_image_cls._optional_components) - set(image_2_image_kwargs.keys())
+        )

        if len(missing_modules) > 0:
            raise ValueError(
@@ -1141,7 +1145,9 @@ def from_pipe(cls, pipeline, **kwargs):
            if k not in inpainting_kwargs
        }

-        missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(inpainting_kwargs.keys())
+        missing_modules = (
+            set(expected_modules) - set(inpainting_cls._optional_components) - set(inpainting_kwargs.keys())
+        )

        if len(missing_modules) > 0:
            raise ValueError(
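
In all three places the fix computes missing modules against the target pipeline class's `_optional_components` rather than the source pipeline instance's, so a component that is optional only for the source is still reported as missing for the target. A set-arithmetic sketch with hypothetical component names:

```python
# Hypothetical component sets for a source pipeline and a target pipeline class.
expected_modules = {"unet", "vae", "text_encoder", "scheduler", "safety_checker"}
source_optional_components = {"safety_checker", "feature_extractor", "image_encoder"}
target_optional_components = {"feature_extractor", "image_encoder"}
passed_kwargs = {"unet", "vae", "text_encoder", "scheduler"}

# Old behavior: safety_checker is optional for the *source* pipeline, so it is
# never reported even though the target class would still need it.
old_missing = expected_modules - source_optional_components - passed_kwargs

# New behavior: subtract the *target* class's optional components instead.
new_missing = expected_modules - target_optional_components - passed_kwargs

assert old_missing == set()
assert new_missing == {"safety_checker"}
```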
