@@ -405,25 +405,39 @@ def fast_pos_embed_interpolate(self,
405405 dh = h_idxs - h_floor
406406 dw = w_idxs - w_floor
407407
408- w00 = ((1 - dh )[:, None ] * (1 - dw )[None , :]).reshape (- 1 )
409- w01 = ((1 - dh )[:, None ] * dw [None , :]).reshape (- 1 )
410- w10 = (dh [:, None ] * (1 - dw )[None , :]).reshape (- 1 )
411- w11 = (dh [:, None ] * dw [None , :]).reshape (- 1 )
412-
413- idx00 = (h_floor [:, None ] * num_grid_per_side +
414- w_floor [None , :]).reshape (- 1 )
415- idx01 = (h_floor [:, None ] * num_grid_per_side +
416- w_ceil [None , :]).reshape (- 1 )
417- idx10 = (h_ceil [:, None ] * num_grid_per_side +
418- w_floor [None , :]).reshape (- 1 )
419- idx11 = (h_ceil [:, None ] * num_grid_per_side +
420- w_ceil [None , :]).reshape (- 1 )
421-
422- indices = torch .stack ([idx00 , idx01 , idx10 , idx11 ], dim = 0 )
408+ # Create meshgrid view for all h, w vars
409+ dh_grid , dw_grid = torch .meshgrid (dh , dw , indexing = 'ij' )
410+ h_floor_grid , w_floor_grid = torch .meshgrid (h_floor ,
411+ w_floor ,
412+ indexing = 'ij' )
413+ h_ceil_grid , w_ceil_grid = torch .meshgrid (h_ceil ,
414+ w_ceil ,
415+ indexing = 'ij' )
416+ h_floor_grid_idx = h_floor_grid * num_grid_per_side
417+ h_ceil_grid_idx = h_ceil_grid * num_grid_per_side
418+
419+ # Direct formulas for the four corner weights, kept for reference:
420+ # w00 = (1 - dh_grid) * (1 - dw_grid)
421+ # w01 = (1 - dh_grid) * dw_grid
422+ # w10 = dh_grid * (1 - dw_grid)
423+ # w11 = dh_grid * dw_grid
424+ # Below, w11 is computed once and the other three weights are derived
425+ # from it, so the dh_grid * dw_grid product is not recomputed.
426+ w11 = dh_grid * dw_grid
427+ w10 = dh_grid - w11
428+ w01 = dw_grid - w11
429+ w00 = 1 - dh_grid - dw_grid + w11
430+
431+ idx00 = h_floor_grid_idx + w_floor_grid
432+ idx01 = h_floor_grid_idx + w_ceil_grid
433+ idx10 = h_ceil_grid_idx + w_floor_grid
434+ idx11 = h_ceil_grid_idx + w_ceil_grid
435+
436+ indices = torch .stack ([idx00 , idx01 , idx10 , idx11 ],
437+ dim = 0 ).reshape (4 , - 1 )
423438 weights = torch .stack ([w00 , w01 , w10 , w11 ],
424- dim = 0 ).to (dtype = self .dtype ,
425- device = self .device )
426- weights = weights .unsqueeze (- 1 )
439+ dim = 0 ).reshape (4 , - 1 , 1 )
440+ weights = weights .to (dtype = self .dtype , device = self .device )
427441
428442 embeds = self .pos_embed (indices )
429443 weighted_embeds = embeds * weights
0 commit comments