Commit c5bd771

deprecate
1 parent 1adf5f0 commit c5bd771

3 files changed: +217 -11 lines changed

src/diffusers/models/embeddings.py

Lines changed: 215 additions & 9 deletions
@@ -85,6 +85,7 @@ def get_3d_sincos_pos_embed(
     spatial_interpolation_scale: float = 1.0,
     temporal_interpolation_scale: float = 1.0,
     device: Optional[torch.device] = None,
+    output_type: str = "np",
 ) -> torch.Tensor:
     r"""
     Creates 3D sinusoidal positional embeddings.
@@ -107,6 +108,20 @@ def get_3d_sincos_pos_embed(
         The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
         embed_dim]`.
     """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_3d_sincos_pos_embed` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt'` to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_3d_sincos_pos_embed_np(
+            embed_dim=embed_dim,
+            spatial_size=spatial_size,
+            temporal_size=temporal_size,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+        )
     if embed_dim % 4 != 0:
         raise ValueError("`embed_dim` must be divisible by 4")
     if isinstance(spatial_size, int):
@@ -122,11 +137,11 @@ def get_3d_sincos_pos_embed(
     grid = torch.stack(grid, dim=0)

     grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
-    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
+    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt")

     # 2. Temporal
     grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
-    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
+    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t, output_type="pt")

     # 3. Concat
     pos_embed_spatial = pos_embed_spatial[None, :, :]
@@ -141,6 +156,66 @@ def get_3d_sincos_pos_embed(
     return pos_embed


+def get_3d_sincos_pos_embed_np(
+    embed_dim: int,
+    spatial_size: Union[int, Tuple[int, int]],
+    temporal_size: int,
+    spatial_interpolation_scale: float = 1.0,
+    temporal_interpolation_scale: float = 1.0,
+) -> np.ndarray:
+    r"""
+    Creates 3D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension of inputs. It must be divisible by 4.
+        spatial_size (`int` or `Tuple[int, int]`):
+            The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to
+            both spatial dimensions (height and width).
+        temporal_size (`int`):
+            The temporal dimension of positional embeddings (number of frames).
+        spatial_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for spatial grid interpolation.
+        temporal_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for temporal grid interpolation.
+
+    Returns:
+        `np.ndarray`:
+            The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
+            embed_dim]`.
+    """
+    if embed_dim % 4 != 0:
+        raise ValueError("`embed_dim` must be divisible by 4")
+    if isinstance(spatial_size, int):
+        spatial_size = (spatial_size, spatial_size)
+
+    embed_dim_spatial = 3 * embed_dim // 4
+    embed_dim_temporal = embed_dim // 4
+
+    # 1. Spatial
+    grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
+    grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid_np(embed_dim_spatial, grid)
+
+    # 2. Temporal
+    grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
+    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid_np(embed_dim_temporal, grid_t)
+
+    # 3. Concat
+    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
+    pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0)  # [T, H*W, D // 4 * 3]
+
+    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
+    pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1)  # [T, H*W, D // 4]
+
+    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)  # [T, H*W, D]
+    return pos_embed
+
+
 def get_2d_sincos_pos_embed(
     embed_dim,
     grid_size,
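The frozen NumPy variant above preserves the documented `[temporal_size, spatial_size[0] * spatial_size[1], embed_dim]` layout. A minimal shape check, assuming a diffusers build that includes this commit:

```python
from diffusers.models.embeddings import get_3d_sincos_pos_embed_np

# 8 frames over a 4x4 patch grid; embed_dim must be divisible by 4
emb = get_3d_sincos_pos_embed_np(embed_dim=64, spatial_size=4, temporal_size=8)
print(emb.shape)  # (8, 16, 64) -> [temporal_size, H*W, embed_dim]
```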
@@ -149,6 +224,7 @@ def get_2d_sincos_pos_embed(
     interpolation_scale=1.0,
     base_size=16,
     device: Optional[torch.device] = None,
+    output_type: str = "np",
 ):
     """
     Creates 2D sinusoidal positional embeddings.
@@ -170,6 +246,21 @@ def get_2d_sincos_pos_embed(
         Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
         embed_dim]` if using cls_token
     """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt'` to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_2d_sincos_pos_embed_np(
+            embed_dim=embed_dim,
+            grid_size=grid_size,
+            cls_token=cls_token,
+            extra_tokens=extra_tokens,
+            interpolation_scale=interpolation_scale,
+            base_size=base_size,
+        )
     if isinstance(grid_size, int):
         grid_size = (grid_size, grid_size)
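Callers that keep the default `output_type="np"` still get the old NumPy behavior, routed through the `*_np` fallback, but diffusers' `deprecate` helper raises a `FutureWarning` until the fallback is removed in 0.33.0. A hedged sketch of what a caller observes:

```python
import warnings

from diffusers.models.embeddings import get_2d_sincos_pos_embed

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    emb = get_2d_sincos_pos_embed(64, 4)  # default output_type="np" hits the deprecated path

print(type(emb))            # <class 'numpy.ndarray'>, from the *_np fallback
print(caught[-1].category)  # FutureWarning, emitted via `deprecate`
```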

@@ -187,13 +278,13 @@ def get_2d_sincos_pos_embed(
     grid = torch.stack(grid, dim=0)

     grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
-    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type)
     if cls_token and extra_tokens > 0:
         pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
     return pos_embed


-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
     r"""
     This function generates 2D sinusoidal positional embeddings from a grid.
@@ -204,18 +295,29 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
     Returns:
         `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
     """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt'` to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_2d_sincos_pos_embed_from_grid_np(
+            embed_dim=embed_dim,
+            grid=grid,
+        )
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")

     # use half of dimensions to encode grid_h
-    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
-    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type)  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type)  # (H*W, D/2)

     emb = torch.concat([emb_h, emb_w], dim=1)  # (H*W, D)
     return emb


-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
     """
     This function generates 1D positional embeddings from a grid.
@@ -226,6 +328,14 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
     Returns:
         `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
     """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt'` to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")

@@ -243,6 +353,94 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
     return emb


+def get_2d_sincos_pos_embed_np(
+    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
+):
+    """
+    Creates 2D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension.
+        grid_size (`int`):
+            The size of the grid height and width.
+        cls_token (`bool`, defaults to `False`):
+            Whether or not to add a classification token.
+        extra_tokens (`int`, defaults to `0`):
+            The number of extra tokens to add.
+        interpolation_scale (`float`, defaults to `1.0`):
+            The scale of the interpolation.
+
+    Returns:
+        pos_embed (`np.ndarray`):
+            Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size *
+            grid_size, embed_dim]` if using cls_token
+    """
+    if isinstance(grid_size, int):
+        grid_size = (grid_size, grid_size)
+
+    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
+    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+    pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid)
+    if cls_token and extra_tokens > 0:
+        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid):
+    r"""
+    This function generates 2D sinusoidal positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension.
+        grid (`np.ndarray`): Grid of positions with shape `(H * W,)`.
+
+    Returns:
+        `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
+    """
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid_np(embed_dim, pos):
+    """
+    This function generates 1D positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension `D`
+        pos (`numpy.ndarray`): 1D tensor of positions with shape `(M,)`
+
+    Returns:
+        `numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
+    """
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    omega = np.arange(embed_dim // 2, dtype=np.float64)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+
 class PatchEmbed(nn.Module):
     """
     2D Image to Patch Embedding with support for SD3 cropping.
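For reference, both 1D builders (`get_1d_sincos_pos_embed_from_grid` and its NumPy twin above) implement the standard transformer sinusoid: for position $m$, channel index $d$, and embedding width $D$,

$$\omega_d = 10000^{-2d/D}, \qquad \mathrm{emb}[m] = \big[\sin(m\,\omega_0), \ldots, \sin(m\,\omega_{D/2-1}),\ \cos(m\,\omega_0), \ldots, \cos(m\,\omega_{D/2-1})\big],$$

which is exactly the `omega`/`einsum` construction followed by `np.concatenate([emb_sin, emb_cos], axis=1)`.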
@@ -305,7 +503,11 @@ def __init__(
             self.pos_embed = None
         elif pos_embed_type == "sincos":
             pos_embed = get_2d_sincos_pos_embed(
-                embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
+                embed_dim,
+                grid_size,
+                base_size=self.base_size,
+                interpolation_scale=self.interpolation_scale,
+                output_type="pt",
             )
             persistent = True if pos_embed_max_size else False
             self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
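`PatchEmbed` now builds its positional table entirely through the torch path, so no `torch.from_numpy` conversion is needed before `register_buffer`. The two backends are intended to agree numerically; a hedged spot check (helper names as introduced in this diff):

```python
import numpy as np

from diffusers.models.embeddings import get_2d_sincos_pos_embed, get_2d_sincos_pos_embed_np

pt = get_2d_sincos_pos_embed(64, 4, output_type="pt")  # new torch path
ref = get_2d_sincos_pos_embed_np(64, 4)                # frozen NumPy path
assert np.allclose(pt.numpy(), ref, atol=1e-5)  # expected to match up to float tolerance
```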
@@ -359,6 +561,7 @@ def forward(self, latent):
                 base_size=self.base_size,
                 interpolation_scale=self.interpolation_scale,
                 device=latent.device,
+                output_type="pt",
             )
             pos_embed = pos_embed.float().unsqueeze(0)
         else:
@@ -486,6 +689,7 @@ def _get_positional_embeddings(
             self.spatial_interpolation_scale,
             self.temporal_interpolation_scale,
             device=device,
+            output_type="pt",
         )
         pos_embedding = pos_embedding.flatten(0, 1)
         joint_pos_embedding = torch.zeros(
@@ -575,7 +779,9 @@ def __init__(
         # Linear projection for text embeddings
         self.text_proj = nn.Linear(text_hidden_size, hidden_size)

-        pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
+        pos_embed = get_2d_sincos_pos_embed(
+            hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
+        )
         pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
         self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
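The net effect for downstream code: the helpers stay in torch end to end and accept a `device`, so call sites that used to wrap the result in `torch.from_numpy` can pass `output_type="pt"` instead. A hedged migration sketch (signatures as shown in this diff; the NumPy default remains available until 0.33.0):

```python
import torch

from diffusers.models.embeddings import get_3d_sincos_pos_embed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# before the torch port: emb = torch.from_numpy(get_3d_sincos_pos_embed(64, (4, 4), 8)).to(device)
# after: build directly on the target device, no NumPy round-trip
emb = get_3d_sincos_pos_embed(
    embed_dim=64,
    spatial_size=(4, 4),
    temporal_size=8,
    device=device,
    output_type="pt",
)
print(emb.shape, emb.device)  # torch.Size([8, 16, 64]) on the chosen device
```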

src/diffusers/models/transformers/latte_transformer_3d.py

Lines changed: 1 addition & 1 deletion
@@ -156,7 +156,7 @@ def __init__(

         # define temporal positional embedding
         temp_pos_embed = get_1d_sincos_pos_embed_from_grid(
-            inner_dim, torch.arange(0, video_length).unsqueeze(1)
+            inner_dim, torch.arange(0, video_length).unsqueeze(1), output_type="pt"
         )  # 1152 hidden size
         self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False)
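With `output_type="pt"`, the temporal table is a `(video_length, inner_dim)` tensor that can be registered as a buffer directly. A minimal sketch of the call above, with illustrative values:

```python
import torch

from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid

video_length, inner_dim = 16, 1152  # illustrative; 1152 matches the "hidden size" comment above
grid = torch.arange(0, video_length).unsqueeze(1)  # frame indices as a (T, 1) grid
temp_pos_embed = get_1d_sincos_pos_embed_from_grid(inner_dim, grid, output_type="pt")
print(temp_pos_embed.shape)  # torch.Size([16, 1152]) -> one embedding per frame
```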

src/diffusers/pipelines/unidiffuser/modeling_uvit.py

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ def __init__(

         self.use_pos_embed = use_pos_embed
         if self.use_pos_embed:
-            pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
+            pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5), output_type="pt")
             self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=False)

     def forward(self, latent):
