Skip to content

Commit 5c6c679

Browse files
Make pe_selection_index_based_on_dim work with torh.compile
1 parent 6bd6b7c commit 5c6c679

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

src/diffusers/models/transformers/auraflow_transformer_2d.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,23 @@ def pe_selection_index_based_on_dim(self, h, w):
7474
# PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
7575
# because original input are in flattened format, we have to flatten this 2d grid as well.
7676
h_p, w_p = h // self.patch_size, w // self.patch_size
77-
original_pe_indexes = torch.arange(self.pos_embed.shape[1])
7877
h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
79-
original_pe_indexes = original_pe_indexes.view(h_max, w_max)
78+
79+
# Calculate the top-left corner indices for the centered patch grid
8080
starth = h_max // 2 - h_p // 2
81-
startw = w_max // 2 - w_p // 2
82-
narrowed = torch.narrow(original_pe_indexes, 0, starth, h_p)
83-
narrowed = torch.narrow(narrowed, 1, startw, w_p)
84-
85-
return narrowed.flatten()
81+
startw = w_max // 2 - w_p // 2
82+
83+
# Generate the row and column indices for the desired patch grid
84+
rows = torch.arange(starth, starth + h_p, device=self.pos_embed.device)
85+
cols = torch.arange(startw, startw + w_p, device=self.pos_embed.device)
86+
87+
# Create a 2D grid of indices
88+
row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")
89+
90+
# Convert the 2D grid indices to flattened 1D indices
91+
selected_indices = (row_indices * w_max + col_indices).flatten()
92+
93+
return selected_indices
8694

8795
def forward(self, latent):
8896
batch_size, num_channels, height, width = latent.size()

0 commit comments

Comments
 (0)