
Commit df96f35

Merge branch 'main' into update-doc-variable-names
2 parents 4366a06 + f7822ae commit df96f35

8 files changed: +214 −65 lines

docs/source/en/tutorials/using_peft_for_inference.md

Lines changed: 6 additions & 6 deletions
@@ -56,7 +56,7 @@ image
 
 With the `adapter_name` parameter, it is really easy to use another adapter for inference! Load the [nerijs/pixel-art-xl](https://huggingface.co/nerijs/pixel-art-xl) adapter that has been fine-tuned to generate pixel art images and call it `"pixel"`.
 
-The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~PeftAdapterMixin.set_adapters`] method:
+The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~loaders.peft.PeftAdapterMixin.set_adapters`] method:
 
 ```python
 pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
@@ -85,7 +85,7 @@ By default, if the most up-to-date versions of PEFT and Transformers are detecte
 
 You can also merge different adapter checkpoints for inference to blend their styles together.
 
-Once again, use the [`~PeftAdapterMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.
+Once again, use the [`~loaders.peft.PeftAdapterMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.
 
 ```python
 pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
@@ -114,7 +114,7 @@ Impressive! As you can see, the model generated an image that mixed the characte
 > [!TIP]
 > Through its PEFT integration, Diffusers also offers more efficient merging methods which you can learn about in the [Merge LoRAs](../using-diffusers/merge_loras) guide!
 
-To return to only using one adapter, use the [`~PeftAdapterMixin.set_adapters`] method to activate the `"toy"` adapter:
+To return to only using one adapter, use the [`~loaders.peft.PeftAdapterMixin.set_adapters`] method to activate the `"toy"` adapter:
 
 ```python
 pipe.set_adapters("toy")
@@ -127,7 +127,7 @@ image = pipe(
 image
 ```
 
-Or to disable all adapters entirely, use the [`~PeftAdapterMixin.disable_lora`] method to return the base model.
+Or to disable all adapters entirely, use the [`~loaders.peft.PeftAdapterMixin.disable_lora`] method to return the base model.
 
 ```python
 pipe.disable_lora()
@@ -141,7 +141,7 @@ image
 
 ### Customize adapters strength
 
-For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~PeftAdapterMixin.set_adapters`].
+For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~loaders.peft.PeftAdapterMixin.set_adapters`].
 
 For example, here's how you can turn on the adapter for the `down` parts, but turn it off for the `mid` and `up` parts:
 ```python
@@ -214,7 +214,7 @@ list_adapters_component_wise
 {"text_encoder": ["toy", "pixel"], "unet": ["toy", "pixel"], "text_encoder_2": ["toy", "pixel"]}
 ```
 
-The [`~PeftAdapterMixin.delete_adapters`] function completely removes an adapter and their LoRA layers from a model.
+The [`~loaders.peft.PeftAdapterMixin.delete_adapters`] function completely removes an adapter and their LoRA layers from a model.
 
 ```py
 pipe.delete_adapters("toy")
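
All of the renamed references above point at the same PEFT helper methods on the pipeline. As a condensed sketch of the workflow this tutorial describes (the SDXL base checkpoint and the `"toy"` adapter repo/weight names come from earlier tutorial steps that are not part of this diff, so treat them as assumptions):

```python
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load two LoRA adapters under explicit names.
# "CiroN2022/toy-face" is assumed from the earlier tutorial steps not shown in this diff.
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

# Blend both adapters at inference time, weighting them differently.
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])

# Return to a single adapter, then turn LoRA off entirely.
pipe.set_adapters("toy")
pipe.disable_lora()

# Remove an adapter and its LoRA layers from the model.
pipe.delete_adapters("toy")
```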

examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 1 addition & 1 deletion
@@ -919,7 +919,7 @@ def preprocess_train(examples):
     # fingerprint used by the cache for the other processes to load the result
     # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
     new_fingerprint = Hasher.hash(args)
-    new_fingerprint_for_vae = Hasher.hash(vae_path)
+    new_fingerprint_for_vae = Hasher.hash((vae_path, args))
     train_dataset_with_embeddings = train_dataset.map(
         compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
     )
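
The likely motivation: `datasets.Dataset.map` reuses a cached result whenever the supplied `new_fingerprint` matches, so a VAE fingerprint derived only from `vae_path` could return stale precomputed latents after the preprocessing arguments change. A minimal sketch of the difference (the path and argument values below are illustrative, not from the script):

```python
from argparse import Namespace

from datasets.fingerprint import Hasher

vae_path = "some/vae-checkpoint"  # illustrative placeholder
args_512 = Namespace(resolution=512, center_crop=False)
args_1024 = Namespace(resolution=1024, center_crop=False)

# Old behaviour: the fingerprint ignores the args, so changing e.g. the resolution
# would not invalidate the cached VAE embeddings.
assert Hasher.hash(vae_path) == Hasher.hash(vae_path)

# New behaviour: the fingerprint covers both the VAE path and the args,
# so the cached dataset is recomputed when either changes.
assert Hasher.hash((vae_path, args_512)) != Hasher.hash((vae_path, args_1024))
```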

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 63 additions & 42 deletions
@@ -85,11 +85,11 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: torch.FloatTensor,
-        image_rotary_emb=None,
-        joint_attention_kwargs=None,
-    ):
+        hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> torch.Tensor:
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
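
The tightened annotations assume the `typing` generics (and `torch`) are already imported at the top of `transformer_flux.py`; no import hunk appears in this diff, so that is presumed rather than shown. A signature-only stub mirroring the new annotation style:

```python
from typing import Any, Dict, Optional, Tuple

import torch


def forward_stub(
    hidden_states: torch.Tensor,
    temb: torch.Tensor,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
    """Stub mirroring the annotated FluxSingleTransformerBlock.forward; torch.Tensor
    covers any dtype/device, whereas torch.FloatTensor names only the CPU float32 type."""
    return hidden_states
```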
@@ -117,15 +117,22 @@ class FluxTransformerBlock(nn.Module):
 
     Reference: https://arxiv.org/abs/2403.03206
 
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
-            processing of `context` conditions.
+    Args:
+        dim (`int`):
+            The embedding dimension of the block.
+        num_attention_heads (`int`):
+            The number of attention heads to use.
+        attention_head_dim (`int`):
+            The number of dimensions to use for each attention head.
+        qk_norm (`str`, defaults to `"rms_norm"`):
+            The normalization to use for the query and key tensors.
+        eps (`float`, defaults to `1e-6`):
+            The epsilon value to use for the normalization.
     """
 
-    def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
+    def __init__(
+        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+    ):
         super().__init__()
 
         self.norm1 = AdaLayerNormZero(dim)
@@ -164,12 +171,12 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_no
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor,
-        temb: torch.FloatTensor,
-        image_rotary_emb=None,
-        joint_attention_kwargs=None,
-    ):
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
@@ -227,16 +234,30 @@ class FluxTransformer2DModel(
 
     Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
 
-    Parameters:
-        patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
-        num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
-        num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
-        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
-        guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+    Args:
+        patch_size (`int`, defaults to `1`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `64`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `None`):
+            The number of channels in the output. If not specified, it defaults to `in_channels`.
+        num_layers (`int`, defaults to `19`):
+            The number of layers of dual stream DiT blocks to use.
+        num_single_layers (`int`, defaults to `38`):
+            The number of layers of single stream DiT blocks to use.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of dimensions to use for each attention head.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of attention heads to use.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The number of dimensions to use for the joint attention (embedding/channel dimension of
+            `encoder_hidden_states`).
+        pooled_projection_dim (`int`, defaults to `768`):
+            The number of dimensions to use for the pooled projection.
+        guidance_embeds (`bool`, defaults to `False`):
+            Whether to use guidance embeddings for guidance-distilled variant of the model.
+        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+            The dimensions to use for the rotary positional embeddings.
     """
 
     _supports_gradient_checkpointing = True
@@ -259,39 +280,39 @@ def __init__(
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.inner_dim = num_attention_heads * attention_head_dim
 
         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
 
         text_time_guidance_cls = (
             CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
         )
         self.time_text_embed = text_time_guidance_cls(
-            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )
 
-        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
-        self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)
+        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+        self.x_embedder = nn.Linear(in_channels, self.inner_dim)
 
         self.transformer_blocks = nn.ModuleList(
             [
                 FluxTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                 )
-                for i in range(self.config.num_layers)
+                for _ in range(num_layers)
             ]
         )
 
         self.single_transformer_blocks = nn.ModuleList(
             [
                 FluxSingleTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                 )
-                for i in range(self.config.num_single_layers)
+                for _ in range(num_single_layers)
             ]
         )
 
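A minimal sketch (not part of this PR) of what the `__init__` change means in practice: `inner_dim` and the block hyperparameters now come straight from the constructor arguments, while `register_to_config` still records the same values, so code that reads `model.config` keeps working. The tiny hyperparameters below are arbitrary and only meant to keep the model small, and the example assumes `FluxTransformer2DModel` is importable from the `diffusers` root, as in recent releases:

```python
from diffusers import FluxTransformer2DModel

# Arbitrary, deliberately tiny hyperparameters (not the documented defaults).
model = FluxTransformer2DModel(
    patch_size=1,
    in_channels=4,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=8,
    num_attention_heads=2,
    joint_attention_dim=16,
    pooled_projection_dim=8,
    axes_dims_rope=(2, 2, 4),  # sums to attention_head_dim
)

# inner_dim is now computed directly from the constructor arguments ...
assert model.inner_dim == 2 * 8  # num_attention_heads * attention_head_dim
# ... and the same values are still recorded on the registered config.
assert model.config.num_attention_heads == 2
```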

@@ -418,16 +439,16 @@ def forward(
         controlnet_single_block_samples=None,
         return_dict: bool = True,
         controlnet_blocks_repeat: bool = False,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`FluxTransformer2DModel`] forward method.
 
         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                 Input `hidden_states`.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                 from the embeddings of input conditions.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
