
Commit 95bb32f

Merge branch 'main' into gpu-test-pr
2 parents: 4bb0e50 + 0404703


5 files changed: +120 additions, -130 deletions

docs/source/en/_toctree.yml

Lines changed: 4 additions & 4 deletions
@@ -543,6 +543,10 @@
       title: Overview
     - local: api/schedulers/cm_stochastic_iterative
      title: CMStochasticIterativeScheduler
+    - local: api/schedulers/ddim_cogvideox
+      title: CogVideoXDDIMScheduler
+    - local: api/schedulers/multistep_dpm_solver_cogvideox
+      title: CogVideoXDPMScheduler
     - local: api/schedulers/consistency_decoder
       title: ConsistencyDecoderScheduler
     - local: api/schedulers/cosine_dpm
@@ -551,8 +555,6 @@
       title: DDIMInverseScheduler
     - local: api/schedulers/ddim
       title: DDIMScheduler
-    - local: api/schedulers/ddim_cogvideox
-      title: CogVideoXDDIMScheduler
     - local: api/schedulers/ddpm
       title: DDPMScheduler
     - local: api/schedulers/deis
@@ -565,8 +567,6 @@
       title: DPMSolverSDEScheduler
     - local: api/schedulers/singlestep_dpm_solver
       title: DPMSolverSinglestepScheduler
-    - local: api/schedulers/multistep_dpm_solver_cogvideox
-      title: CogVideoXDPMScheduler
     - local: api/schedulers/edm_multistep_dpm_solver
       title: EDMDPMSolverMultistepScheduler
     - local: api/schedulers/edm_euler

src/diffusers/models/attention_processor.py

Lines changed: 1 addition & 1 deletion
@@ -1410,7 +1410,7 @@ class JointAttnProcessor2_0:
 
     def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
     def __call__(
         self,
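For context (not part of this commit), the renamed message belongs to the standard PyTorch 2.0 guard: the processor detects PyTorch >= 2.0 by probing for `F.scaled_dot_product_attention`. A minimal sketch of mirroring that check before opting into the processor, using the import path of the file shown above:

    import torch.nn.functional as F

    from diffusers.models.attention_processor import JointAttnProcessor2_0

    # PyTorch 2.0 is detected by the presence of scaled_dot_product_attention;
    # constructing the processor on older PyTorch raises the (now correctly named) ImportError.
    if hasattr(F, "scaled_dot_product_attention"):
        processor = JointAttnProcessor2_0()
    else:
        raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")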

src/diffusers/models/controlnets/controlnet_sd3.py

Lines changed: 54 additions & 12 deletions
@@ -40,6 +40,48 @@ class SD3ControlNetOutput(BaseOutput):
 
 
 class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+    r"""
+    ControlNet model for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
+
+    Parameters:
+        sample_size (`int`, defaults to `128`):
+            The width/height of the latents. This is fixed during training since it is used to learn a number of
+            position embeddings.
+        patch_size (`int`, defaults to `2`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `16`):
+            The number of latent channels in the input.
+        num_layers (`int`, defaults to `18`):
+            The number of layers of transformer blocks to use.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        num_attention_heads (`int`, defaults to `18`):
+            The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The embedding dimension to use for joint text-image attention.
+        caption_projection_dim (`int`, defaults to `1152`):
+            The embedding dimension of caption embeddings.
+        pooled_projection_dim (`int`, defaults to `2048`):
+            The embedding dimension of pooled text projections.
+        out_channels (`int`, defaults to `16`):
+            The number of latent channels in the output.
+        pos_embed_max_size (`int`, defaults to `96`):
+            The maximum latent height/width of positional embeddings.
+        extra_conditioning_channels (`int`, defaults to `0`):
+            The number of extra channels to use for conditioning for patch embedding.
+        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
+            The number of dual-stream transformer blocks to use.
+        qk_norm (`str`, *optional*, defaults to `None`):
+            The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
+        pos_embed_type (`str`, defaults to `"sincos"`):
+            The type of positional embedding to use. Choose between `"sincos"` and `None`.
+        use_pos_embed (`bool`, defaults to `True`):
+            Whether to use positional embeddings.
+        force_zeros_for_pooled_projection (`bool`, defaults to `True`):
+            Whether to force zeros for pooled projection embeddings. This is handled in the pipelines by reading the
+            config value of the ControlNet model.
+    """
+
     _supports_gradient_checkpointing = True
 
     @register_to_config
@@ -93,7 +135,7 @@ def __init__(
                 JointTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    attention_head_dim=attention_head_dim,
                     context_pre_only=False,
                     qk_norm=qk_norm,
                     use_dual_attention=True if i in dual_attention_layers else False,
@@ -108,7 +150,7 @@ def __init__(
                 SD3SingleTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    attention_head_dim=attention_head_dim,
                 )
                 for _ in range(num_layers)
             ]
@@ -297,28 +339,28 @@ def from_transformer(
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
+        hidden_states: torch.Tensor,
         controlnet_cond: torch.Tensor,
         conditioning_scale: float = 1.0,
-        encoder_hidden_states: torch.FloatTensor = None,
-        pooled_projections: torch.FloatTensor = None,
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
 
         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                 Input `hidden_states`.
             controlnet_cond (`torch.Tensor`):
                 The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
             conditioning_scale (`float`, defaults to `1.0`):
                 The scale factor for ControlNet outputs.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                 from the embeddings of input conditions.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
@@ -437,11 +479,11 @@ def __init__(self, controlnets):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
+        hidden_states: torch.Tensor,
         controlnet_cond: List[torch.tensor],
         conditioning_scale: List[float],
-        pooled_projections: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
+        pooled_projections: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
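As a quick illustration of the parameters documented in the new docstring, a hedged sketch (not part of this commit) that constructs the model directly; the values follow the documented defaults except `num_layers`, which is shrunk here only to keep the example light.

    from diffusers import SD3ControlNetModel

    # Construct a ControlNet using the config names documented above.
    # num_layers is reduced from the default 18 so the example stays small.
    controlnet = SD3ControlNetModel(
        sample_size=128,
        patch_size=2,
        in_channels=16,
        num_layers=2,
        attention_head_dim=64,
        num_attention_heads=18,  # inner_dim = 18 * 64 = 1152
        joint_attention_dim=4096,
        caption_projection_dim=1152,
        pooled_projection_dim=2048,
        out_channels=16,
        pos_embed_max_size=96,
        extra_conditioning_channels=0,
    )
    print(sum(p.numel() for p in controlnet.parameters()))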

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 9 additions & 46 deletions
@@ -18,7 +18,6 @@
 import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
@@ -32,7 +31,7 @@
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..cache_utils import CacheMixin
@@ -45,20 +44,7 @@
 
 @maybe_allow_in_graph
 class FluxSingleTransformerBlock(nn.Module):
-    r"""
-    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
-
-    Reference: https://arxiv.org/abs/2403.03206
-
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
-            processing of `context` conditions.
-    """
-
-    def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
+    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
         super().__init__()
         self.mlp_hidden_dim = int(dim * mlp_ratio)
 
@@ -68,9 +54,15 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
         if is_torch_npu_available():
+            deprecation_message = (
+                "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
+                "should be set explicitly using the `set_attn_processor` method."
+            )
+            deprecate("npu_processor", "0.34.0", deprecation_message)
             processor = FluxAttnProcessor2_0_NPU()
         else:
             processor = FluxAttnProcessor2_0()
+
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -113,39 +105,14 @@ def forward(
 
 @maybe_allow_in_graph
 class FluxTransformerBlock(nn.Module):
-    r"""
-    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
-
-    Reference: https://arxiv.org/abs/2403.03206
-
-    Args:
-        dim (`int`):
-            The embedding dimension of the block.
-        num_attention_heads (`int`):
-            The number of attention heads to use.
-        attention_head_dim (`int`):
-            The number of dimensions to use for each attention head.
-        qk_norm (`str`, defaults to `"rms_norm"`):
-            The normalization to use for the query and key tensors.
-        eps (`float`, defaults to `1e-6`):
-            The epsilon value to use for the normalization.
-    """
-
     def __init__(
         self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
     ):
         super().__init__()
 
         self.norm1 = AdaLayerNormZero(dim)
-
         self.norm1_context = AdaLayerNormZero(dim)
 
-        if hasattr(F, "scaled_dot_product_attention"):
-            processor = FluxAttnProcessor2_0()
-        else:
-            raise ValueError(
-                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
-            )
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -155,7 +122,7 @@ def __init__(
             out_dim=dim,
             context_pre_only=False,
             bias=True,
-            processor=processor,
+            processor=FluxAttnProcessor2_0(),
             qk_norm=qk_norm,
             eps=eps,
         )
@@ -166,10 +133,6 @@ def __init__(
         self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
         self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
-        # let chunk size default to None
-        self._chunk_size = None
-        self._chunk_dim = 0
-
     def forward(
         self,
         hidden_states: torch.Tensor,
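The deprecation added above steers users toward setting the processor explicitly rather than relying on the NPU default chosen at `__init__` time. A minimal sketch of that migration (not part of this commit; the checkpoint name is only an example, and an NPU environment is assumed):

    from diffusers import FluxTransformer2DModel
    from diffusers.models.attention_processor import FluxAttnProcessor2_0_NPU

    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="transformer"
    )
    # Opt into the NPU attention processor explicitly; set_attn_processor swaps the
    # attention implementation on every attention module of the model.
    transformer.set_attn_processor(FluxAttnProcessor2_0_NPU())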
