Skip to content

Commit 4f00b4e

Browse files
committed
Use torch in get_2d_rotary_pos_embed
1 parent 6131a93 commit 4f00b4e

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/diffusers/models/embeddings.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -733,10 +733,11 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
733733
`torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
734734
"""
735735
start, stop = crops_coords
736-
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
737-
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
738-
grid = np.meshgrid(grid_w, grid_h) # here w goes first
739-
grid = np.stack(grid, axis=0) # [2, W, H]
736+
# scale end by (steps−1)/steps matches np.linspace(..., endpoint=False)
737+
grid_h = torch.linspace(start[0], stop[0] * (grid_size[0] - 1) / grid_size[0], grid_size[0], dtype=torch.float32)
738+
grid_w = torch.linspace(start[1], stop[1] * (grid_size[1] - 1) / grid_size[1], grid_size[1], dtype=torch.float32)
739+
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
740+
grid = torch.stack(grid, dim=0) # [2, W, H]
740741

741742
grid = grid.reshape([2, 1, *grid.shape[1:]])
742743
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)

0 commit comments

Comments
 (0)