
Commit 79ed8c1

up
1 parent 30a8008

File tree

1 file changed: +5 −6 lines


src/diffusers/models/transformers/transformer_lumina2.py

Lines changed: 5 additions & 6 deletions
@@ -260,7 +260,6 @@ def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
         return torch.cat(result, dim=-1).to(device)
 
     def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
-        # Get batch info and dimensions
         batch_size, channels, height, width = hidden_states.shape
         p = self.patch_size
         post_patch_height, post_patch_width = height // p, width // p
@@ -276,11 +275,11 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
         position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
 
         for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
-            # Set caption positions
+            # add caption position ids
             position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
             position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len
 
-            # Set image patch positions
+            # add image position ids
             row_ids = (
                 torch.arange(post_patch_height, dtype=torch.int32, device=device)
                 .view(-1, 1)
@@ -296,10 +295,10 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
             position_ids[i, cap_seq_len:seq_len, 1] = row_ids
             position_ids[i, cap_seq_len:seq_len, 2] = col_ids
 
-        # Get frequencies
+        # Get combined rotary embeddings
         freqs_cis = self._get_freqs_cis(position_ids)
 
-        # Split frequencies for captions and images
+        # create separate rotary embeddings for captions and images
         cap_freqs_cis = torch.zeros(
             batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
         )
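
For orientation, the renamed comments above describe a three-axis position scheme: axis 0 indexes caption tokens (all image tokens share the value cap_seq_len), while axes 1 and 2 hold each image patch's row and column. Below is a minimal standalone sketch of that layout; the sizes are illustrative, and the repeat/flatten completion of the row_ids/col_ids chains is a plausible reconstruction, since the diff truncates them mid-expression.

import torch

# Illustrative sizes; the real values come from the caption length and patch grid.
cap_seq_len = 4            # caption tokens
post_patch_height = 2      # patch rows
post_patch_width = 3       # patch columns
image_seq_len = post_patch_height * post_patch_width
seq_len = cap_seq_len + image_seq_len

position_ids = torch.zeros(seq_len, 3, dtype=torch.int32)

# caption position ids: axis 0 counts caption tokens; image tokens share one id
position_ids[:cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32)
position_ids[cap_seq_len:seq_len, 0] = cap_seq_len

# image position ids: axis 1 is the patch row, axis 2 the patch column
# (repeat/flatten assumed here; the diff cuts the chain off after .view(-1, 1))
row_ids = (
    torch.arange(post_patch_height, dtype=torch.int32)
    .view(-1, 1)
    .repeat(1, post_patch_width)
    .flatten()
)
col_ids = (
    torch.arange(post_patch_width, dtype=torch.int32)
    .view(1, -1)
    .repeat(post_patch_height, 1)
    .flatten()
)
position_ids[cap_seq_len:seq_len, 1] = row_ids
position_ids[cap_seq_len:seq_len, 2] = col_ids

print(position_ids)  # one (caption, row, col) triple per token

Feeding these ids into _get_freqs_cis then yields one rotary embedding per token, which the next hunk splits back into caption and image halves.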
@@ -311,7 +310,7 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
             cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
             img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len]
 
-        # patch embeddings
+        # image patch embeddings
         hidden_states = (
             hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
             .permute(0, 2, 4, 3, 5, 1)
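
The "image patch embeddings" comment in the final hunk covers the view/permute that folds each p×p patch into the channel dimension, producing one token per patch. A toy-shaped sketch of that pattern follows; the trailing flatten is truncated in the diff, so the reshape here is an assumed completion, not the verbatim source.

import torch

# Toy shapes; the real ones come from the model inputs and self.patch_size.
batch_size, channels, height, width = 1, 4, 8, 8
p = 2
post_patch_height, post_patch_width = height // p, width // p

hidden_states = torch.randn(batch_size, channels, height, width)

# Split H and W into (patch index, p), then move the two p-axes and channels last
patches = (
    hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
    .permute(0, 2, 4, 3, 5, 1)
)
# Assumed final step: flatten to one token of length p*p*channels per patch
patches = patches.reshape(batch_size, post_patch_height * post_patch_width, p * p * channels)

print(patches.shape)  # torch.Size([1, 16, 16])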
