@@ -68,6 +68,21 @@ def __init__(
6868        self .height , self .width  =  height  //  patch_size , width  //  patch_size 
6969        self .base_size  =  height  //  patch_size 
7070
def pe_selection_index_based_on_dim(self, h, w):
    """Select positional-embedding indices for an ``h x w`` latent.

    The stored positional embedding covers a square grid of
    ``sqrt(pos_embed_max_size) x sqrt(pos_embed_max_size)`` patch positions.
    Since the transformer input is flattened, the PE is viewed as a 2D grid,
    a centered ``(h // patch_size) x (w // patch_size)`` crop is taken, and
    the selected indices are flattened back to 1D.

    Args:
        h: Latent height (pre-patchify), in the same units as ``patch_size``.
        w: Latent width (pre-patchify), in the same units as ``patch_size``.

    Returns:
        1D ``torch.LongTensor`` of flattened positional-embedding indices.

    Raises:
        ValueError: If the requested crop does not fit inside the stored
            positional-embedding grid (previously this silently produced a
            wrong, truncated index set via a negative slice start).
    """
    h_p, w_p = h // self.patch_size, w // self.patch_size
    # The PE grid is assumed square: pos_embed.shape[1] == grid_side ** 2
    # (the .view() below enforces this at runtime) — TODO confirm upstream.
    grid_side = int(self.pos_embed_max_size**0.5)
    if h_p > grid_side or w_p > grid_side:
        raise ValueError(
            f"Requested a {h_p}x{w_p} positional-embedding crop for latent size "
            f"({h}, {w}), but only a {grid_side}x{grid_side} grid is available."
        )
    pe_indices = torch.arange(self.pos_embed.shape[1]).view(grid_side, grid_side)
    # Centered crop of the 2D index grid.
    start_h = grid_side // 2 - h_p // 2
    start_w = grid_side // 2 - w_p // 2
    pe_indices = pe_indices[start_h : start_h + h_p, start_w : start_w + w_p]
    return pe_indices.flatten()
7186    def  forward (self , latent ):
7287        batch_size , num_channels , height , width  =  latent .size ()
7388        latent  =  latent .view (
@@ -80,7 +95,8 @@ def forward(self, latent):
8095        )
8196        latent  =  latent .permute (0 , 2 , 4 , 1 , 3 , 5 ).flatten (- 3 ).flatten (1 , 2 )
8297        latent  =  self .proj (latent )
83-         return  latent  +  self .pos_embed 
98+         pe_index  =  self .pe_selection_index_based_on_dim (height , width )
99+         return  latent  +  self .pos_embed [:, pe_index ]
84100
85101
86102# Taken from the original Aura flow inference code. 
0 commit comments