Commit 30a8008

Commit message: up
Parent: bde26d3
1 file changed: 34 additions, 39 deletions

src/diffusers/models/transformers/transformer_lumina2.py

@@ -264,22 +264,21 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
         batch_size, channels, height, width = hidden_states.shape
         p = self.patch_size
         post_patch_height, post_patch_width = height // p, width // p
-        num_patches = post_patch_height * post_patch_width
+        image_seq_len = post_patch_height * post_patch_width
         device = hidden_states.device

-        # Get caption lengths and calculate max sequence length
+        encoder_seq_len = attention_mask.shape[1]
         l_effective_cap_len = attention_mask.sum(dim=1).tolist()
-        max_seq_len = max(l_effective_cap_len) + num_patches
+        seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
+        max_seq_len = max(seq_lengths)

         # Create position IDs
         position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)

-        for i in range(batch_size):
-            cap_len = l_effective_cap_len[i]
-
+        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
             # Set caption positions
-            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
-            position_ids[i, cap_len : cap_len + num_patches, 0] = cap_len
+            position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
+            position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len

             # Set image patch positions
             row_ids = (
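
A minimal numeric sketch of the per-sample bookkeeping introduced above (made-up mask values and patch grid, not part of the commit): every prompt keeps its own effective caption length, and its joint sequence length is that caption length plus the fixed number of image tokens.

import torch

# hypothetical batch: 2 captions padded to length 6, a 2x2 grid of image patches
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])
image_seq_len = 2 * 2  # post_patch_height * post_patch_width

l_effective_cap_len = attention_mask.sum(dim=1).tolist()        # [3, 5]
seq_lengths = [c + image_seq_len for c in l_effective_cap_len]  # [7, 9]
max_seq_len = max(seq_lengths)                                  # 9
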
@@ -294,34 +293,33 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
                 .repeat(post_patch_height, 1)
                 .flatten()
             )
-            position_ids[i, cap_len : cap_len + num_patches, 1] = row_ids
-            position_ids[i, cap_len : cap_len + num_patches, 2] = col_ids
+            position_ids[i, cap_seq_len:seq_len, 1] = row_ids
+            position_ids[i, cap_seq_len:seq_len, 2] = col_ids

         # Get frequencies
         freqs_cis = self._get_freqs_cis(position_ids)

         # Split frequencies for captions and images
         cap_freqs_cis = torch.zeros(
-            batch_size, attention_mask.shape[1], freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+            batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+        )
+        img_freqs_cis = torch.zeros(
+            batch_size, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
         )
-        img_freqs_cis = torch.zeros(batch_size, num_patches, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)

-        for i in range(batch_size):
-            cap_len = l_effective_cap_len[i]
-            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
-            img_freqs_cis[i, :num_patches] = freqs_cis[i, cap_len : cap_len + num_patches]
+        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
+            img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len]

         # patch embeddings
         hidden_states = (
-            hidden_states.view(
-                batch_size, channels, post_patch_height, self.patch_size, post_patch_width, self.patch_size
-            )
+            hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
             .permute(0, 2, 4, 3, 5, 1)
             .flatten(3)
             .flatten(1, 2)
         )

-        return hidden_states, freqs_cis, cap_freqs_cis, img_freqs_cis
+        return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths


 class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
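
A runnable toy version of the patch-embedding reshape that the hunk above simplifies (arbitrary sizes, not from the commit): the latent is cut into p x p patches and flattened into a token sequence of length post_patch_height * post_patch_width, i.e. the image_seq_len used throughout.

import torch

batch_size, channels, height, width, p = 1, 4, 8, 8, 2
post_patch_height, post_patch_width = height // p, width // p

hidden_states = torch.randn(batch_size, channels, height, width)
tokens = (
    hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
    .permute(0, 2, 4, 3, 5, 1)  # (batch, H', W', p, p, channels)
    .flatten(3)                 # merge the p, p, channels dims per patch
    .flatten(1, 2)              # merge the patch-grid dims into one token axis
)
print(tokens.shape)  # torch.Size([1, 16, 16]) == (batch, image_seq_len, p * p * channels)
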
@@ -468,22 +466,17 @@ def forward(
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         # 1. Condition, positional & patch embedding
         batch_size, _, height, width = hidden_states.shape
-        p = self.config.patch_size
-        post_patch_height, post_patch_width = height // p, width // p
-        num_patches = post_patch_height * post_patch_width
-
-        # effective_text_seq_lengths is based on actual caption length, so it's different for each prompt in a batch
-        effective_encoder_seq_lengths = encoder_attention_mask.sum(dim=1).tolist()
-        seq_lengths = [
-            encoder_seq_len + num_patches for encoder_seq_len in effective_encoder_seq_lengths
-        ]  # Add num_patches to each length
-        max_seq_len = max(seq_lengths)

         temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)

-        hidden_states, rotary_emb, context_rotary_emb, noise_rotary_emb = self.rope_embedder(
-            hidden_states, encoder_attention_mask
-        )
+        (
+            hidden_states,
+            context_rotary_emb,
+            noise_rotary_emb,
+            rotary_emb,
+            encoder_seq_lengths,
+            seq_lengths,
+        ) = self.rope_embedder(hidden_states, encoder_attention_mask)

         hidden_states = self.x_embedder(hidden_states)

@@ -497,12 +490,13 @@ def forward(
             hidden_states = layer(hidden_states, None, noise_rotary_emb, temb)

         # 3. Joint Transformer blocks
+        max_seq_len = max(seq_lengths)
         attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
         joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
-        for i, (effective_encoder_seq_len, seq_len) in enumerate(zip(effective_encoder_seq_lengths, seq_lengths)):
+        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
             attention_mask[i, :seq_len] = True
-            joint_hidden_states[i, :effective_encoder_seq_len] = encoder_hidden_states[i, :effective_encoder_seq_len]
-            joint_hidden_states[i, effective_encoder_seq_len:seq_len] = hidden_states[i]
+            joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
+            joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i]

         hidden_states = joint_hidden_states
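
A self-contained sketch of the packing pattern the loop above implements (toy shapes and values; names mirror the diff): each row of the padded joint buffer holds that sample's valid caption tokens followed by its image tokens, and the boolean mask marks only those positions as attendable.

import torch

batch_size, hidden_size = 2, 8
encoder_seq_lengths = [3, 5]                 # valid caption tokens per sample
image_seq_len = 4                            # image tokens per sample (same for all)
seq_lengths = [c + image_seq_len for c in encoder_seq_lengths]
max_seq_len = max(seq_lengths)

encoder_hidden_states = torch.randn(batch_size, 6, hidden_size)       # padded captions
hidden_states = torch.randn(batch_size, image_seq_len, hidden_size)   # image tokens

attention_mask = torch.zeros(batch_size, max_seq_len, dtype=torch.bool)
joint_hidden_states = torch.zeros(batch_size, max_seq_len, hidden_size)
for i, (enc_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
    attention_mask[i, :seq_len] = True
    joint_hidden_states[i, :enc_len] = encoder_hidden_states[i, :enc_len]
    joint_hidden_states[i, enc_len:seq_len] = hidden_states[i]
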

@@ -520,11 +514,12 @@ def forward(
         hidden_states = self.norm_out(hidden_states, temb)

         # 5. Unpatchify
+        p = self.config.patch_size
         output = []
-        for i, (effective_encoder_seq_len, seq_len) in enumerate(zip(effective_encoder_seq_lengths, seq_lengths)):
+        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
             output.append(
-                hidden_states[i][effective_encoder_seq_len:seq_len]
-                .view(post_patch_height, post_patch_width, p, p, self.out_channels)
+                hidden_states[i][encoder_seq_len:seq_len]
+                .view(height // p, width // p, p, p, self.out_channels)
                 .permute(4, 0, 2, 1, 3)
                 .flatten(3, 4)
                 .flatten(1, 2)
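
Finally, a hedged sketch of the unpatchify step above for a single sample (toy sizes; it assumes the token dimension at this point is p * p * out_channels, as the view implies): the image tokens are sliced back out of the joint sequence and rearranged into an (out_channels, height, width) latent.

import torch

p, out_channels = 2, 4
height, width = 8, 8
encoder_seq_len = 3
seq_len = encoder_seq_len + (height // p) * (width // p)

# one sample's joint sequence after the transformer blocks (toy values)
sample = torch.randn(seq_len, p * p * out_channels)

image = (
    sample[encoder_seq_len:seq_len]                      # keep only the image tokens
    .view(height // p, width // p, p, p, out_channels)
    .permute(4, 0, 2, 1, 3)                              # (C, H', p, W', p)
    .flatten(3, 4)                                       # (C, H', p, W)
    .flatten(1, 2)                                       # (C, H, W)
)
print(image.shape)  # torch.Size([4, 8, 8])
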
