
Commit 74b5fed

Authored by Aishwarya0811, DN6, and github-actions[bot]
Fix MPS compatibility in get_1d_sincos_pos_embed_from_grid huggingface#12432 (huggingface#12449)
* Fix MPS compatibility in get_1d_sincos_pos_embed_from_grid huggingface#12432
* Fix trailing whitespace in docstring
* Apply style fixes

Co-authored-by: Dhruv Nair <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 85eb505 commit 74b5fed
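
For context, the failure this commit fixes comes from MPS lacking float64 support. A minimal repro sketch, not part of the commit; the exact exception type and message vary by PyTorch version:

    import torch

    # On Apple Silicon, creating a float64 tensor on the MPS device fails,
    # which is what the old torch.arange(..., dtype=torch.float64) call hit.
    if torch.backends.mps.is_available():
        try:
            torch.arange(4, device="mps", dtype=torch.float64)
        except (TypeError, RuntimeError) as exc:  # exception type varies by PyTorch version
            print(f"float64 on MPS failed: {exc}")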

1 file changed

src/diffusers/models/embeddings.py

Lines changed: 10 additions & 2 deletions
@@ -319,13 +319,17 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
     return emb
 
 
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False):
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False, dtype=None):
     """
     This function generates 1D positional embeddings from a grid.
 
     Args:
         embed_dim (`int`): The embedding dimension `D`
         pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
+        output_type (`str`, *optional*, defaults to `"np"`): Output type. Use `"pt"` for PyTorch tensors.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip sine and cosine embeddings.
+        dtype (`torch.dtype`, *optional*): Data type for frequency calculations. If `None`, defaults to
+            `torch.float32` on MPS devices (which don't support `torch.float64`) and `torch.float64` on other devices.
 
     Returns:
         `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
@@ -341,7 +345,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False, dtype=None):
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
 
-    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
+    # Auto-detect appropriate dtype if not specified
+    if dtype is None:
+        dtype = torch.float32 if pos.device.type == "mps" else torch.float64
+
+    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype)
     omega /= embed_dim / 2.0
     omega = 1.0 / 10000**omega  # (D/2,)
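
A minimal usage sketch of the patched helper (assuming the import path matches this file, src/diffusers/models/embeddings.py, and that output_type="pt" returns a torch.Tensor as the docstring states):

    import torch
    from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid

    # Pick MPS when available; with this fix, the frequency dtype falls back
    # to float32 there instead of the unsupported float64.
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    pos = torch.arange(16, device=device, dtype=torch.float32)  # positions, shape (M,)
    emb = get_1d_sincos_pos_embed_from_grid(64, pos, output_type="pt")
    print(emb.shape)  # expected: torch.Size([16, 64])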
