Skip to content

Commit 493c49b

Browse files
DTG2005a-r-r-o-w
andauthored
Update src/diffusers/models/embeddings.py
Co-authored-by: Aryan <[email protected]>
1 parent ed2cf8c commit 493c49b

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

src/diffusers/models/embeddings.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,15 @@ def get_2d_sincos_pos_embed(
176176

177177

178178
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
179-
"""
180-
This function generates 2D positional embeddings from a grid.
179+
r"""
180+
This function generates 2D sinusoidal positional embeddings from a grid.
181181
182182
Args:
183-
embed_dim (`int`): output dimension for each position
184-
grid (`np.ndarray`): grid of positions
185-
Output:
186-
`np.ndarray`: tensor in shape (grid_size*grid_size, embed_dim)
183+
embed_dim (`int`): The embedding dimension.
184+
grid (`np.ndarray`): Grid of positions with shape `(H * W,)`.
185+
186+
Returns:
187+
`np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
187188
"""
188189
if embed_dim % 2 != 0:
189190
raise ValueError("embed_dim must be divisible by 2")

0 commit comments

Comments
 (0)