 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
 
 @maybe_allow_in_graph
 class SD3SingleTransformerBlock(nn.Module):
-    r"""
-    A Single Transformer block as part of the MMDiT architecture, used in Stable Diffusion 3 ControlNet.
-
-    Reference: https://arxiv.org/abs/2403.03206
-
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-    """
-
     def __init__(
         self,
         dim: int,
@@ -59,45 +47,31 @@ def __init__(
         super().__init__()
 
         self.norm1 = AdaLayerNormZero(dim)
-
-        if hasattr(F, "scaled_dot_product_attention"):
-            processor = JointAttnProcessor2_0()
-        else:
-            raise ValueError(
-                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
-            )
-
         self.attn = Attention(
             query_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             bias=True,
-            processor=processor,
+            processor=JointAttnProcessor2_0(),
             eps=1e-6,
         )
 
         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
     def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
+        # 1. Attention
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
-        # Attention.
-        attn_output = self.attn(
-            hidden_states=norm_hidden_states,
-            encoder_hidden_states=None,
-        )
-
-        # Process attention outputs for the `hidden_states`.
+        attn_output = self.attn(hidden_states=norm_hidden_states, encoder_hidden_states=None)
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
 
+        # 2. Feed Forward
         norm_hidden_states = self.norm2(hidden_states)
-        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
         ff_output = self.ff(norm_hidden_states)
         ff_output = gate_mlp.unsqueeze(1) * ff_output
-
         hidden_states = hidden_states + ff_output
 
         return hidden_states
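
A minimal standalone sketch (not part of the commit) of the adaLN modulation-and-gating pattern that the feed-forward branch above follows; `modulate_and_gate` and the toy shapes are illustrative assumptions, not diffusers API:

import torch
import torch.nn as nn

def modulate_and_gate(x, sublayer, scale, shift, gate):
    # x: (batch, seq_len, dim); scale/shift/gate: (batch, dim), e.g. projected from a timestep embedding
    normed = nn.functional.layer_norm(x, x.shape[-1:])                 # parameter-free LayerNorm
    modulated = normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)  # per-sample scale/shift
    return x + gate.unsqueeze(1) * sublayer(modulated)                 # gated residual update

x = torch.randn(2, 8, 64)
scale, shift, gate = (torch.randn(2, 64) for _ in range(3))
print(modulate_and_gate(x, nn.Linear(64, 64), scale, shift, gate).shape)  # torch.Size([2, 8, 64])
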
@@ -107,26 +81,40 @@ class SD3Transformer2DModel(
     ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
 ):
     """
-    The Transformer model introduced in Stable Diffusion 3.
-
-    Reference: https://arxiv.org/abs/2403.03206
+    The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
 
     Parameters:
-        sample_size (`int`): The width of the latent images. This is fixed during training since
-            it is used to learn a number of position embeddings.
-        patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
-        num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
-        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
-        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
-        out_channels (`int`, defaults to 16): Number of output channels.
-
+        sample_size (`int`, defaults to `128`):
+            The width/height of the latents. This is fixed during training since it is used to learn a number of
+            position embeddings.
+        patch_size (`int`, defaults to `2`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `16`):
+            The number of latent channels in the input.
+        num_layers (`int`, defaults to `18`):
+            The number of layers of transformer blocks to use.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        num_attention_heads (`int`, defaults to `18`):
+            The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The embedding dimension to use for joint text-image attention.
+        caption_projection_dim (`int`, defaults to `1152`):
+            The embedding dimension of caption embeddings.
+        pooled_projection_dim (`int`, defaults to `2048`):
+            The embedding dimension of pooled text projections.
+        out_channels (`int`, defaults to `16`):
+            The number of latent channels in the output.
+        pos_embed_max_size (`int`, defaults to `96`):
+            The maximum latent height/width of positional embeddings.
+        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
+            The number of dual-stream transformer blocks to use.
+        qk_norm (`str`, *optional*, defaults to `None`):
+            The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["JointTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
 
     @register_to_config
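
A hedged usage sketch (not part of the commit) that builds the model with a toy configuration and runs one forward pass; the argument names come from the docstring above and the forward signature below, while the tiny sizes are assumptions chosen for illustration only:

import torch
from diffusers import SD3Transformer2DModel

model = SD3Transformer2DModel(
    sample_size=32,
    patch_size=2,
    in_channels=16,
    num_layers=2,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=32,
    caption_projection_dim=32,  # kept equal to num_attention_heads * attention_head_dim
    pooled_projection_dim=64,
    out_channels=16,
    pos_embed_max_size=32,
)

latents = torch.randn(1, 16, 32, 32)    # (batch, in_channels, height, width)
prompt_embeds = torch.randn(1, 77, 32)  # (batch, sequence_len, joint_attention_dim)
pooled = torch.randn(1, 64)             # (batch, pooled_projection_dim)
timestep = torch.tensor([1])

sample = model(
    hidden_states=latents,
    encoder_hidden_states=prompt_embeds,
    pooled_projections=pooled,
    timestep=timestep,
).sample
print(sample.shape)  # torch.Size([1, 16, 32, 32])
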
@@ -149,36 +137,33 @@ def __init__(
         qk_norm: Optional[str] = None,
     ):
         super().__init__()
-        default_out_channels = in_channels
-        self.out_channels = out_channels if out_channels is not None else default_out_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.out_channels = out_channels if out_channels is not None else in_channels
+        self.inner_dim = num_attention_heads * attention_head_dim
 
         self.pos_embed = PatchEmbed(
-            height=self.config.sample_size,
-            width=self.config.sample_size,
-            patch_size=self.config.patch_size,
-            in_channels=self.config.in_channels,
+            height=sample_size,
+            width=sample_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
             embed_dim=self.inner_dim,
             pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
         )
         self.time_text_embed = CombinedTimestepTextProjEmbeddings(
-            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )
-        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
+        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
 
-        # `attention_head_dim` is doubled to account for the mixing.
-        # It needs to crafted when we get the actual checkpoints.
         self.transformer_blocks = nn.ModuleList(
             [
                 JointTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                     context_pre_only=i == num_layers - 1,
                     qk_norm=qk_norm,
                     use_dual_attention=True if i in dual_attention_layers else False,
                 )
-                for i in range(self.config.num_layers)
+                for i in range(num_layers)
             ]
         )
 
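
The hunk above swaps `self.config.<arg>` reads for the constructor arguments themselves. A minimal sketch of why the two are interchangeable, assuming only the public `ConfigMixin`/`register_to_config` behavior (`TinyModel` is hypothetical):

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

class TinyModel(ModelMixin, ConfigMixin):
    config_name = "config.json"

    @register_to_config
    def __init__(self, num_attention_heads: int = 18, attention_head_dim: int = 64):
        super().__init__()
        # `register_to_config` records every __init__ argument on `self.config`,
        # so the local names and `self.config.*` hold the same values.
        self.inner_dim = num_attention_heads * attention_head_dim

model = TinyModel()
assert model.config.num_attention_heads == 18
assert model.inner_dim == model.config.num_attention_heads * model.config.attention_head_dim
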
@@ -331,24 +316,24 @@ def unfuse_qkv_projections(self):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        pooled_projections: torch.FloatTensor = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         block_controlnet_hidden_states: List = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         skip_layers: Optional[List[int]] = None,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
 
         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                 Input `hidden_states`.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`):
                 Embeddings projected from the embeddings of input conditions.
             timestep (`torch.LongTensor`):
                 Used to indicate denoising step.