@@ -68,6 +68,21 @@ def __init__(
         self.height, self.width = height // patch_size, width // patch_size
         self.base_size = height // patch_size
 
+    def pe_selection_index_based_on_dim(self, h, w):
+        # select a subset of the positional embedding based on H, W, where H, W is the size of the latent
+        # the PE is viewed as a 2d grid, and an H/p x W/p block of it is selected
+        # because the original input is in flattened format, we have to flatten this 2d grid as well
+        h_p, w_p = h // self.patch_size, w // self.patch_size
+        original_pe_indexes = torch.arange(self.pos_embed.shape[1])
+        h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
+        original_pe_indexes = original_pe_indexes.view(h_max, w_max)
+        starth = h_max // 2 - h_p // 2
+        endh = starth + h_p
+        startw = w_max // 2 - w_p // 2
+        endw = startw + w_p
+        original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
+        return original_pe_indexes.flatten()
+
     def forward(self, latent):
         batch_size, num_channels, height, width = latent.size()
         latent = latent.view(
@@ -80,7 +95,8 @@ def forward(self, latent):
         )
         latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
         latent = self.proj(latent)
-        return latent + self.pos_embed
+        pe_index = self.pe_selection_index_based_on_dim(height, width)
+        return latent + self.pos_embed[:, pe_index]
 
 
 # Taken from the original Aura flow inference code.
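
For intuition, here is a minimal standalone sketch of the same selection logic, using hypothetical sizes rather than the actual module: the full positional-embedding table is viewed as an h_max x w_max grid, a centered h_p x w_p window is cropped out, and that window is flattened to line up with the flattened patch sequence.

import torch

# Hypothetical sizes for illustration only.
pos_embed_max_size = 36          # PE table covers a 6x6 patch grid
patch_size = 2
height, width = 8, 4             # latent size -> 4x2 patches

h_p, w_p = height // patch_size, width // patch_size
h_max = w_max = int(pos_embed_max_size**0.5)

# View the flat PE indices as a 2d grid and take a centered crop.
indexes = torch.arange(pos_embed_max_size).view(h_max, w_max)
starth = h_max // 2 - h_p // 2   # 1
startw = w_max // 2 - w_p // 2   # 2
selected = indexes[starth : starth + h_p, startw : startw + w_p]

print(selected.flatten())        # tensor([ 8,  9, 14, 15, 20, 21, 26, 27])

These flat indices are what self.pos_embed[:, pe_index] gathers in forward, so a latent smaller than the maximum supported size reuses the centered portion of the embedding grid.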