Skip to content

Commit 2bf6677

Browse files
committed
fix: pos ids
1 parent b6e8793 commit 2bf6677

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/diffusers/models/transformers/transformer_lumina2_accessory.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,9 @@ def forward(
132132
for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
133133
# add caption position ids
134134
position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
135-
position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len
136135

137136
# add condition image position ids
137+
position_ids[i, cap_seq_len:cond_image_seq_len, 0] = cap_seq_len
138138
cond_row_ids = (
139139
torch.arange(post_patch_cond_height, dtype=torch.int32, device=device)
140140
.view(-1, 1)
@@ -159,6 +159,8 @@ def forward(
159159
)
160160

161161
# add image position ids
162+
position_ids[i, cap_seq_len + cond_image_seq_len : seq_len, 0] = cap_seq_len + 1
163+
162164
row_ids = (
163165
torch.arange(post_patch_height, dtype=torch.int32, device=device)
164166
.view(-1, 1)

0 commit comments

Comments
 (0)