@@ -319,12 +319,16 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
319319 assert embed_dim % 4 == 0
320320
321321 # use half of dimensions to encode grid_h
322- emb_h = get_1d_rotary_pos_embed (embed_dim // 2 , grid [0 ].reshape (- 1 ), use_real = use_real ) # (H*W, D/4)
323- emb_w = get_1d_rotary_pos_embed (embed_dim // 2 , grid [1 ].reshape (- 1 ), use_real = use_real ) # (H*W, D/4)
322+ emb_h = get_1d_rotary_pos_embed (
323+ embed_dim // 2 , grid [0 ].reshape (- 1 ), use_real = use_real
324+ ) # (H*W, D/2) if use_real else (H*W, D/4)
325+ emb_w = get_1d_rotary_pos_embed (
326+ embed_dim // 2 , grid [1 ].reshape (- 1 ), use_real = use_real
327+ ) # (H*W, D/2) if use_real else (H*W, D/4)
324328
325329 if use_real :
326- cos = torch .cat ([emb_h [0 ], emb_w [0 ]], dim = 1 ) # (H*W, D/2 )
327- sin = torch .cat ([emb_h [1 ], emb_w [1 ]], dim = 1 ) # (H*W, D/2 )
330+ cos = torch .cat ([emb_h [0 ], emb_w [0 ]], dim = 1 ) # (H*W, D)
331+ sin = torch .cat ([emb_h [1 ], emb_w [1 ]], dim = 1 ) # (H*W, D)
328332 return cos , sin
329333 else :
330334 emb = torch .cat ([emb_h , emb_w ], dim = 1 ) # (H*W, D/2)
@@ -371,6 +375,8 @@ def get_1d_rotary_pos_embed(
371375 Returns:
372376 `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
373377 """
378+ assert dim % 2 == 0
379+
374380 if isinstance (pos , int ):
375381 pos = np .arange (pos )
376382 theta = theta * ntk_factor
0 commit comments