@@ -319,13 +319,17 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
319319 return emb
320320
321321
322- def get_1d_sincos_pos_embed_from_grid (embed_dim , pos , output_type = "np" , flip_sin_to_cos = False ):
322+ def get_1d_sincos_pos_embed_from_grid (embed_dim , pos , output_type = "np" , flip_sin_to_cos = False , dtype = None ):
323323 """
324324 This function generates 1D positional embeddings from a grid.
325325
326326 Args:
327327 embed_dim (`int`): The embedding dimension `D`
328328 pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
329+ output_type (`str`, *optional*, defaults to `"np"`): Output type. Use `"pt"` for PyTorch tensors.
330+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip sine and cosine embeddings.
331+ dtype (`torch.dtype`, *optional*): Data type for frequency calculations. If `None`, defaults to
332+ `torch.float32` on MPS devices (which don't support `torch.float64`) and `torch.float64` on other devices.
329333
330334 Returns:
331335 `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
@@ -341,7 +345,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin
341345 if embed_dim % 2 != 0 :
342346 raise ValueError ("embed_dim must be divisible by 2" )
343347
344- omega = torch .arange (embed_dim // 2 , device = pos .device , dtype = torch .float64 )
348+ # Auto-detect appropriate dtype if not specified
349+ if dtype is None :
350+ dtype = torch .float32 if pos .device .type == "mps" else torch .float64
351+
352+ omega = torch .arange (embed_dim // 2 , device = pos .device , dtype = dtype )
345353 omega /= embed_dim / 2.0
346354 omega = 1.0 / 10000 ** omega # (D/2,)
347355
0 commit comments