Skip to content

Commit 85ea323

Browse files
Update pe_selection_index_based_on_dim
1 parent 680a8ed commit 85ea323

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,11 @@ def pe_selection_index_based_on_dim(self, h, w):
7878
h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
7979
original_pe_indexes = original_pe_indexes.view(h_max, w_max)
8080
starth = h_max // 2 - h_p // 2
81-
endh = starth + h_p
82-
startw = w_max // 2 - w_p // 2
83-
endw = startw + w_p
84-
original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
85-
return original_pe_indexes.flatten()
81+
startw = w_max // 2 - w_p // 2
82+
narrowed = torch.narrow(original_pe_indexes, 0, starth, h_p)
83+
narrowed = torch.narrow(narrowed, 1, startw, w_p)
84+
85+
return narrowed.flatten()
8686

8787
def forward(self, latent):
8888
batch_size, num_channels, height, width = latent.size()

0 commit comments

Comments
 (0)