Commit 1a0cd1c

Use torch in get_2d_sincos_pos_embed
1 parent 6131a93 commit 1a0cd1c

2 files changed: +36 -21 lines

src/diffusers/models/embeddings.py

Lines changed: 35 additions & 20 deletions
@@ -139,7 +139,13 @@ def get_3d_sincos_pos_embed(
 
 
 def get_2d_sincos_pos_embed(
-    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
+    embed_dim,
+    grid_size,
+    cls_token=False,
+    extra_tokens=0,
+    interpolation_scale=1.0,
+    base_size=16,
+    device: Optional[torch.device] = None,
 ):
     """
     Creates 2D sinusoidal positional embeddings.
@@ -157,22 +163,30 @@ def get_2d_sincos_pos_embed(
             The scale of the interpolation.
 
     Returns:
-        pos_embed (`np.ndarray`):
+        pos_embed (`torch.Tensor`):
             Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
             embed_dim]` if using cls_token
     """
     if isinstance(grid_size, int):
         grid_size = (grid_size, grid_size)
 
-    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
-    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
-    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
-    grid = np.stack(grid, axis=0)
+    grid_h = (
+        torch.arange(grid_size[0], device=device, dtype=torch.float32)
+        / (grid_size[0] / base_size)
+        / interpolation_scale
+    )
+    grid_w = (
+        torch.arange(grid_size[1], device=device, dtype=torch.float32)
+        / (grid_size[1] / base_size)
+        / interpolation_scale
+    )
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
+    grid = torch.stack(grid, dim=0)
 
     grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
     if cls_token and extra_tokens > 0:
-        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+        pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
     return pos_embed
 
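After this hunk the function allocates everything with torch from the start, including the optional device placement. A minimal sketch of calling it (not part of the commit; assumes the usual import path in this repo):

    import torch
    from diffusers.models.embeddings import get_2d_sincos_pos_embed

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    pos_embed = get_2d_sincos_pos_embed(embed_dim=64, grid_size=(16, 16), device=device)

    print(pos_embed.shape)   # torch.Size([256, 64])
    print(pos_embed.device)  # same as `device`; no numpy round-trip
    print(pos_embed.dtype)   # float64, which is why call sites cast with .float()
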
@@ -182,10 +196,10 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
 
     Args:
         embed_dim (`int`): The embedding dimension.
-        grid (`np.ndarray`): Grid of positions with shape `(H * W,)`.
+        grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
 
     Returns:
-        `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
+        `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
     """
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
@@ -194,7 +208,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
 
-    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    emb = torch.concat([emb_h, emb_w], dim=1)  # (H*W, D)
     return emb
 
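As the shape comments say, the 2D table is two half-width 1D tables side by side, one per grid axis. A quick sketch of that split (illustration only, not from the commit):

    import torch
    from diffusers.models.embeddings import (
        get_1d_sincos_pos_embed_from_grid,
        get_2d_sincos_pos_embed_from_grid,
    )

    coords = torch.arange(4, dtype=torch.float32)
    grid = torch.stack(torch.meshgrid(coords, coords, indexing="xy"), dim=0)
    grid = grid.reshape([2, 1, 4, 4])

    emb = get_2d_sincos_pos_embed_from_grid(8, grid)      # (16, 8)
    half = get_1d_sincos_pos_embed_from_grid(4, grid[0])  # (16, 4)
    print(torch.equal(emb[:, :4], half))  # True: columns 0..3 come from grid[0]
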
@@ -204,25 +218,25 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
 
     Args:
         embed_dim (`int`): The embedding dimension `D`
-        pos (`numpy.ndarray`): 1D tensor of positions with shape `(M,)`
+        pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
 
     Returns:
-        `numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
+        `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
     """
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
 
-    omega = np.arange(embed_dim // 2, dtype=np.float64)
+    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
     omega /= embed_dim / 2.0
     omega = 1.0 / 10000**omega  # (D/2,)
 
     pos = pos.reshape(-1)  # (M,)
-    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+    out = torch.outer(pos, omega)  # (M, D/2), outer product
 
-    emb_sin = np.sin(out)  # (M, D/2)
-    emb_cos = np.cos(out)  # (M, D/2)
+    emb_sin = torch.sin(out)  # (M, D/2)
+    emb_cos = torch.cos(out)  # (M, D/2)
 
-    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    emb = torch.concat([emb_sin, emb_cos], dim=1)  # (M, D)
     return emb
 
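The only non-mechanical substitution in this function is np.einsum("m,d->md", ...) becoming torch.outer: for two 1D vectors both compute the same (M, D/2) outer product. A quick check (illustration only):

    import numpy as np
    import torch

    pos = np.arange(8, dtype=np.float64)
    omega = 1.0 / 10000 ** (np.arange(4, dtype=np.float64) / 4.0)

    out_np = np.einsum("m,d->md", pos, omega)
    out_pt = torch.outer(torch.from_numpy(pos), torch.from_numpy(omega))
    print(np.allclose(out_np, out_pt.numpy()))  # True: identical outer product
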
@@ -291,7 +305,7 @@ def __init__(
                 embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
             )
             persistent = True if pos_embed_max_size else False
-            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
+            self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
         else:
             raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
 
@@ -341,8 +355,9 @@ def forward(self, latent):
                 grid_size=(height, width),
                 base_size=self.base_size,
                 interpolation_scale=self.interpolation_scale,
+                device=latent.device,
             )
-            pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
+            pos_embed = pos_embed.float().unsqueeze(0)
         else:
             pos_embed = self.pos_embed
 
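With device=latent.device threaded through, the table is born on the right device and the old trailing .to(latent.device) copy disappears. A minimal sketch of the difference (illustration only; the zeros stand in for the computed table):

    import numpy as np
    import torch

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Before: computed on CPU via numpy, then the finished table is copied over.
    cpu_table = torch.from_numpy(np.zeros((256, 64), dtype=np.float32)).to(device)

    # After: allocated on the target device from the first torch.arange call,
    # so every downstream op (meshgrid, outer, sin/cos, concat) stays there.
    on_device = torch.arange(16, device=device, dtype=torch.float32)
    print(on_device.device)  # already `device`; no host-to-device copy of the result
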
@@ -554,7 +569,7 @@ def __init__(
 
         pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
         pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
-        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
+        self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
 
     def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, channel, height, width = hidden_states.shape
 
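The register_buffer call sites keep their persistent flags. Since the table is now cheap to rebuild on any device, persistent=False (here and in modeling_uvit.py below) keeps it out of checkpoints entirely. A generic PyTorch sketch with a hypothetical Demo module (not from this diff):

    import torch
    from torch import nn

    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            # Non-persistent buffers move with .to()/.cuda() like any buffer,
            # but are excluded from state_dict(), so checkpoints stay lean.
            self.register_buffer("pos_embed", torch.zeros(1, 256, 64), persistent=False)

    print("pos_embed" in Demo().state_dict())  # False
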
src/diffusers/pipelines/unidiffuser/modeling_uvit.py

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ def __init__(
         self.use_pos_embed = use_pos_embed
         if self.use_pos_embed:
             pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
-            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+            self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=False)
 
     def forward(self, latent):
         latent = self.proj(latent)