Skip to content

Commit 697ecec

Browse files
DTG2005a-r-r-o-w
andauthored
Update src/diffusers/models/embeddings.py
Co-authored-by: Aryan <[email protected]>
1 parent 493c49b commit 697ecec

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

src/diffusers/models/embeddings.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,13 +199,14 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
199199

200200
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
201201
"""
202-
This function generates 1D positional embeddings from sin and cos values.
202+
This function generates 1D positional embeddings from a grid.
203203
204204
Args:
205-
embed_dim(`int`): output dimension for each position
206-
pos(`numpy.ndarray(dtype=float)`): tensor in shape (M, 1)
207-
Output:
208-
`numpy.ndarray(dtype=float)`: tensor in shape (M, D)
205+
embed_dim (`int`): The embedding dimension `D`
206+
pos (`numpy.ndarray`): 1D tensor of positions with shape `(M,)`
207+
208+
Returns:
209+
`numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
209210
"""
210211
if embed_dim % 2 != 0:
211212
raise ValueError("embed_dim must be divisible by 2")

0 commit comments

Comments
 (0)