File tree Expand file tree Collapse file tree 1 file changed +7
-6
lines changed Expand file tree Collapse file tree 1 file changed +7
-6
lines changed Original file line number Diff line number Diff line change @@ -176,14 +176,15 @@ def get_2d_sincos_pos_embed(
176176
177177
178178def  get_2d_sincos_pos_embed_from_grid (embed_dim , grid ):
179-     """ 
180-     This function generates 2D positional embeddings from a grid. 
179+     r """
180+     This function generates 2D sinusoidal  positional embeddings from a grid. 
181181
182182    Args: 
183-         embed_dim (`int`): output dimension for each position 
184-         grid (`np.ndarray`): grid of positions 
185-     Output: 
186-         `np.ndarray`: tensor in shape (grid_size*grid_size, embed_dim) 
183+         embed_dim (`int`): The embedding dimension. 
184+         grid (`np.ndarray`): Grid of positions with shape `(H * W,)`. 
185+      
186+     Returns: 
187+         `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)` 
187188    """ 
188189    if  embed_dim  %  2  !=  0 :
189190        raise  ValueError ("embed_dim must be divisible by 2" )
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments