
Commit a9a5b14

[Core] refactor transformers 2d into multiple init variants. (#7491)
* refactor transformers 2d into multiple legacy variants.
* fix: init.
* fix recursive init.
* add inits.
* make transformer block creation more modular.
* complete refactor.
* remove forward
* debug
* remove legacy blocks and refactor within the module itself.
* remove print
* guard caption projection
* remove fetcher.
* reduce the number of args.
* fix: norm_type
* group variables that are shared.
* remove _get_transformer_blocks
* harmonize the init function signatures.
* transformer_blocks to common
* repeat.
1 parent aa19025 commit a9a5b14

File tree

1 file changed: +149 / -78 lines changed
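The heart of the change: `Transformer2DModel.__init__` no longer builds the input, transformer, and output blocks inline for every input type. It now validates the config, sets the attributes shared by all variants, and dispatches to one of three helpers (`_init_continuous_input`, `_init_vectorized_inputs`, `_init_patched_inputs`). Below is a minimal sketch of that dispatch pattern, written against a stand-in class rather than the real diffusers model; the flag computation and placeholder layers are paraphrased for illustration only.

import torch.nn as nn


class InitDispatchSketch(nn.Module):
    """Illustration only: mirrors the init-dispatch structure introduced by this commit."""

    def __init__(self, in_channels=None, num_vector_embeds=None, patch_size=None):
        super().__init__()
        # Shared state is set once, up front (paraphrasing the diff below).
        self.in_channels = in_channels
        self.gradient_checkpointing = False

        # Flags mirroring is_input_continuous / is_input_vectorized / is_input_patches.
        is_continuous = in_channels is not None and patch_size is None
        is_vectorized = num_vector_embeds is not None
        is_patches = in_channels is not None and patch_size is not None

        # Each helper builds its own input, transformer, and output blocks.
        if is_continuous:
            self._init_continuous_input()
        elif is_vectorized:
            self._init_vectorized_inputs()
        elif is_patches:
            self._init_patched_inputs()

    def _init_continuous_input(self):
        self.proj_in = nn.Identity()  # placeholder for the real projection layers

    def _init_vectorized_inputs(self):
        self.latent_image_embedding = nn.Identity()  # placeholder

    def _init_patched_inputs(self):
        self.pos_embed = nn.Identity()  # placeholder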

src/diffusers/models/transformers/transformer_2d.py

Lines changed: 149 additions & 78 deletions
@@ -102,6 +102,8 @@ def __init__(
         interpolation_scale: float = None,
     ):
         super().__init__()
+
+        # Validate inputs.
         if patch_size is not None:
             if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
                 raise NotImplementedError(
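This validation runs before any layers are built, so an unsupported combination fails immediately. A hedged example of exercising it (it assumes the public `Transformer2DModel` export and its default `norm_type`, neither of which is shown in this hunk):

from diffusers import Transformer2DModel

# With a patch_size but a norm_type outside {"ada_norm", "ada_norm_zero", "ada_norm_single"},
# the constructor raises NotImplementedError before creating any blocks.
try:
    Transformer2DModel(num_attention_heads=2, attention_head_dim=8, in_channels=4, patch_size=2)
except NotImplementedError as err:
    print(f"rejected: {err}")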
@@ -112,10 +114,16 @@ def __init__(
                     f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
                 )
 
+        # Set some common variables used across the board.
         self.use_linear_projection = use_linear_projection
+        self.interpolation_scale = interpolation_scale
+        self.caption_channels = caption_channels
         self.num_attention_heads = num_attention_heads
         self.attention_head_dim = attention_head_dim
-        inner_dim = num_attention_heads * attention_head_dim
+        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.gradient_checkpointing = False
 
         # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
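Grouping the shared variables also makes the derived width explicit: `inner_dim` is simply heads times head dimension, and it is what every `BasicTransformerBlock` and projection layer below is sized against. For example, with `num_attention_heads=16` and `attention_head_dim=88` (assumed library defaults, not part of this hunk):

# inner_dim drives the width of proj_in/proj_out and every BasicTransformerBlock.
num_attention_heads = 16  # assumed diffusers default
attention_head_dim = 88   # assumed diffusers default
inner_dim = num_attention_heads * attention_head_dim
print(inner_dim)  # 1408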
@@ -150,104 +158,167 @@ def __init__(
                 f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
             )
 
-        # 2. Define input layers
+        # 2. Initialize the right blocks.
+        # These functions follow a common structure:
+        # a. Initialize the input blocks. b. Initialize the transformer blocks.
+        # c. Initialize the output blocks and other projection blocks when necessary.
         if self.is_input_continuous:
-            self.in_channels = in_channels
-
-            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
-            if use_linear_projection:
-                self.proj_in = nn.Linear(in_channels, inner_dim)
-            else:
-                self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+            self._init_continuous_input(norm_type=norm_type)
         elif self.is_input_vectorized:
-            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
-            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
+            self._init_vectorized_inputs(norm_type=norm_type)
+        elif self.is_input_patches:
+            self._init_patched_inputs(norm_type=norm_type)
 
-            self.height = sample_size
-            self.width = sample_size
-            self.num_vector_embeds = num_vector_embeds
-            self.num_latent_pixels = self.height * self.width
+    def _init_continuous_input(self, norm_type):
+        self.norm = torch.nn.GroupNorm(
+            num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True
+        )
+        if self.use_linear_projection:
+            self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim)
+        else:
+            self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0)
 
-            self.latent_image_embedding = ImagePositionalEmbeddings(
-                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
-            )
-        elif self.is_input_patches:
-            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
 
-            self.height = sample_size
-            self.width = sample_size
+        if self.use_linear_projection:
+            self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels)
+        else:
+            self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0)
 
-            self.patch_size = patch_size
-            interpolation_scale = (
-                interpolation_scale if interpolation_scale is not None else max(self.config.sample_size // 64, 1)
-            )
-            self.pos_embed = PatchEmbed(
-                height=sample_size,
-                width=sample_size,
-                patch_size=patch_size,
-                in_channels=in_channels,
-                embed_dim=inner_dim,
-                interpolation_scale=interpolation_scale,
-            )
+    def _init_vectorized_inputs(self, norm_type):
+        assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+        assert (
+            self.config.num_vector_embeds is not None
+        ), "Transformer2DModel over discrete input must provide num_embed"
+
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+        self.num_latent_pixels = self.height * self.width
+
+        self.latent_image_embedding = ImagePositionalEmbeddings(
+            num_embed=self.config.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width
+        )
 
-        # 3. Define transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
                 BasicTransformerBlock(
-                    inner_dim,
-                    num_attention_heads,
-                    attention_head_dim,
-                    dropout=dropout,
-                    cross_attention_dim=cross_attention_dim,
-                    activation_fn=activation_fn,
-                    num_embeds_ada_norm=num_embeds_ada_norm,
-                    attention_bias=attention_bias,
-                    only_cross_attention=only_cross_attention,
-                    double_self_attention=double_self_attention,
-                    upcast_attention=upcast_attention,
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
                     norm_type=norm_type,
-                    norm_elementwise_affine=norm_elementwise_affine,
-                    norm_eps=norm_eps,
-                    attention_type=attention_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
                 )
-                for d in range(num_layers)
+                for _ in range(self.config.num_layers)
             ]
         )
 
-        # 4. Define output layers
-        self.out_channels = in_channels if out_channels is None else out_channels
-        if self.is_input_continuous:
-            # TODO: should use out_channels for continuous projections
-            if use_linear_projection:
-                self.proj_out = nn.Linear(inner_dim, in_channels)
-            else:
-                self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
-        elif self.is_input_vectorized:
-            self.norm_out = nn.LayerNorm(inner_dim)
-            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
-        elif self.is_input_patches and norm_type != "ada_norm_single":
-            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
-            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
-            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
-        elif self.is_input_patches and norm_type == "ada_norm_single":
-            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
-            self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
-            self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
-
-        # 5. PixArt-Alpha blocks.
+        self.norm_out = nn.LayerNorm(self.inner_dim)
+        self.out = nn.Linear(self.inner_dim, self.config.num_vector_embeds - 1)
+
+    def _init_patched_inputs(self, norm_type):
+        assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+
+        self.patch_size = self.config.patch_size
+        interpolation_scale = (
+            self.config.interpolation_scale
+            if self.config.interpolation_scale is not None
+            else max(self.config.sample_size // 64, 1)
+        )
+        self.pos_embed = PatchEmbed(
+            height=self.config.sample_size,
+            width=self.config.sample_size,
+            patch_size=self.config.patch_size,
+            in_channels=self.in_channels,
+            embed_dim=self.inner_dim,
+            interpolation_scale=interpolation_scale,
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
+
+        if self.config.norm_type != "ada_norm_single":
+            self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+            self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
+            self.proj_out_2 = nn.Linear(
+                self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+            )
+        elif self.config.norm_type == "ada_norm_single":
+            self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+            self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+            self.proj_out = nn.Linear(
+                self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+            )
+
+        # PixArt-Alpha blocks.
         self.adaln_single = None
         self.use_additional_conditions = False
-        if norm_type == "ada_norm_single":
+        if self.config.norm_type == "ada_norm_single":
             self.use_additional_conditions = self.config.sample_size == 128
             # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
             # additional conditions until we find better name
-            self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
+            self.adaln_single = AdaLayerNormSingle(
+                self.inner_dim, use_additional_conditions=self.use_additional_conditions
+            )
 
         self.caption_projection = None
-        if caption_channels is not None:
-            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
-
-        self.gradient_checkpointing = False
+        if self.caption_channels is not None:
+            self.caption_projection = PixArtAlphaTextProjection(
+                in_features=self.caption_channels, hidden_size=self.inner_dim
+            )
 
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
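Which helper runs is decided by the `is_input_continuous` / `is_input_vectorized` / `is_input_patches` flags the class already derives from its config; the flag computation sits earlier in the file and is only paraphrased in the small, self-contained sketch below.

from typing import Optional


def select_init_variant(
    in_channels: Optional[int] = None,
    num_vector_embeds: Optional[int] = None,
    patch_size: Optional[int] = None,
) -> str:
    # Paraphrased flags; the real class computes these from its registered config.
    is_input_continuous = (in_channels is not None) and (patch_size is None)
    is_input_vectorized = num_vector_embeds is not None
    is_input_patches = (in_channels is not None) and (patch_size is not None)

    if is_input_continuous:
        return "_init_continuous_input"   # GroupNorm + proj_in/proj_out (e.g. UNet-style latents)
    elif is_input_vectorized:
        return "_init_vectorized_inputs"  # ImagePositionalEmbeddings + LayerNorm head
    elif is_input_patches:
        return "_init_patched_inputs"     # PatchEmbed + DiT/PixArt-style output projections
    raise ValueError("Cannot infer the input type from the given arguments.")


print(select_init_variant(in_channels=4))                # _init_continuous_input
print(select_init_variant(num_vector_embeds=8192))       # _init_vectorized_inputs
print(select_init_variant(in_channels=4, patch_size=2))  # _init_patched_inputs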
@@ -361,7 +432,7 @@ def forward(
                 )
 
         # 2. Blocks
-        if self.caption_projection is not None:
+        if self.is_input_patches and self.caption_projection is not None:
            batch_size = hidden_states.shape[0]
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
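With the refactor, `caption_projection` is only created inside `_init_patched_inputs`, so continuous and vectorized models never define the attribute; checking `self.is_input_patches` first lets Python's short-circuit `and` skip the attribute access for them. A tiny illustration (a hypothetical stub, not diffusers code) of the behaviour the guard relies on:

class _NonPatchedStub:
    # Continuous/vectorized case: is_input_patches is False and
    # caption_projection was never assigned in __init__.
    is_input_patches = False


stub = _NonPatchedStub()
# The right operand is never evaluated, so no AttributeError is raised.
if stub.is_input_patches and stub.caption_projection is not None:
    print("project captions")
else:
    print("skip caption projection")  # this branch runs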
