Skip to content

Commit 6f0b4d6

Browse files
committed
Use torch in get_3d_sincos_pos_embed
1 parent 1a0cd1c commit 6f0b4d6

File tree

1 file changed: +24 additions, −16 deletions

src/diffusers/models/embeddings.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ def get_3d_sincos_pos_embed(
     temporal_size: int,
     spatial_interpolation_scale: float = 1.0,
     temporal_interpolation_scale: float = 1.0,
-) -> np.ndarray:
+    device: Optional[torch.device] = None,
+) -> torch.Tensor:
     r"""
     Creates 3D sinusoidal positional embeddings.

@@ -102,7 +103,7 @@ def get_3d_sincos_pos_embed(
             Scale factor for temporal grid interpolation.

     Returns:
-        `np.ndarray`:
+        `torch.Tensor`:
             The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
             embed_dim]`.
     """
@@ -115,26 +116,28 @@ def get_3d_sincos_pos_embed(
     embed_dim_temporal = embed_dim // 4

     # 1. Spatial
-    grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
-    grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
-    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
-    grid = np.stack(grid, axis=0)
+    grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
+    grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
+    grid = torch.stack(grid, dim=0)

     grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
     pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)

     # 2. Temporal
-    grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
+    grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
     pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)

     # 3. Concat
-    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
-    pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0)  # [T, H*W, D // 4 * 3]
+    pos_embed_spatial = pos_embed_spatial[None, :, :]
+    pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0)  # [T, H*W, D // 4 * 3]

-    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
-    pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1)  # [T, H*W, D // 4]
+    pos_embed_temporal = pos_embed_temporal[:, None, :]
+    pos_embed_temporal = pos_embed_temporal.repeat_interleave(
+        spatial_size[0] * spatial_size[1], dim=1
+    )  # [T, H*W, D // 4]

-    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)  # [T, H*W, D]
+    pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1)  # [T, H*W, D]
     return pos_embed

@@ -468,7 +471,9 @@ def __init__(
         pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
         self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)

-    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+    def _get_positional_embeddings(
+        self, sample_height: int, sample_width: int, sample_frames: int, device: Optional[torch.device] = None
+    ) -> torch.Tensor:
         post_patch_height = sample_height // self.patch_size
         post_patch_width = sample_width // self.patch_size
         post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
@@ -480,8 +485,9 @@ def _get_positional_embeddings(self, sample_height: int, sample_width: int, samp
             post_time_compression_frames,
             self.spatial_interpolation_scale,
             self.temporal_interpolation_scale,
+            device=device,
         )
-        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
+        pos_embedding = pos_embedding.flatten(0, 1)
         joint_pos_embedding = torch.zeros(
             1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
         )
@@ -536,8 +542,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
             or self.sample_width != width
             or self.sample_frames != pre_time_compression_frames
         ):
-            pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
-            pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+            pos_embedding = self._get_positional_embeddings(
+                height, width, pre_time_compression_frames, device=embeds.device
+            )
+            pos_embedding = pos_embedding.to(dtype=embeds.dtype)
         else:
             pos_embedding = self.pos_embedding

0 commit comments

Comments
 (0)