@@ -74,15 +74,23 @@ def pe_selection_index_based_on_dim(self, h, w):
7474        # PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected 
7575        # because original input are in flattened format, we have to flatten this 2d grid as well. 
7676        h_p , w_p  =  h  //  self .patch_size , w  //  self .patch_size 
77-         original_pe_indexes  =  torch .arange (self .pos_embed .shape [1 ])
7877        h_max , w_max  =  int (self .pos_embed_max_size ** 0.5 ), int (self .pos_embed_max_size ** 0.5 )
79-         original_pe_indexes  =  original_pe_indexes .view (h_max , w_max )
78+ 
79+         # Calculate the top-left corner indices for the centered patch grid 
8080        starth  =  h_max  //  2  -  h_p  //  2 
81-         startw  =  w_max  //  2  -  w_p  //  2         
82-         narrowed  =  torch .narrow (original_pe_indexes , 0 , starth , h_p )
83-         narrowed  =  torch .narrow (narrowed , 1 , startw , w_p )
84-         
85-         return  narrowed .flatten ()
81+         startw  =  w_max  //  2  -  w_p  //  2 
82+ 
83+         # Generate the row and column indices for the desired patch grid 
84+         rows  =  torch .arange (starth , starth  +  h_p , device = self .pos_embed .device )
85+         cols  =  torch .arange (startw , startw  +  w_p , device = self .pos_embed .device )
86+ 
87+         # Create a 2D grid of indices 
88+         row_indices , col_indices  =  torch .meshgrid (rows , cols , indexing = "ij" )
89+ 
90+         # Convert the 2D grid indices to flattened 1D indices 
91+         selected_indices  =  (row_indices  *  w_max  +  col_indices ).flatten ()
92+ 
93+         return  selected_indices 
8694
8795    def  forward (self , latent ):
8896        batch_size , num_channels , height , width  =  latent .size ()
0 commit comments