Skip to content

Commit f545353

Browse files
authored
Merge branch 'main' into loading_method_integration
2 parents cb525df + 6324340 commit f545353

File tree

5 files changed

+274
-20
lines changed

5 files changed

+274
-20
lines changed

src/diffusers/models/embeddings.py

Lines changed: 243 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,78 @@ def get_3d_sincos_pos_embed(
8484
temporal_size: int,
8585
spatial_interpolation_scale: float = 1.0,
8686
temporal_interpolation_scale: float = 1.0,
87+
device: Optional[torch.device] = None,
88+
output_type: str = "np",
89+
) -> torch.Tensor:
90+
r"""
91+
Creates 3D sinusoidal positional embeddings.
92+
93+
Args:
94+
embed_dim (`int`):
95+
The embedding dimension of inputs. It must be divisible by 16.
96+
spatial_size (`int` or `Tuple[int, int]`):
97+
The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
98+
spatial dimensions (height and width).
99+
temporal_size (`int`):
100+
The temporal dimension of positional embeddings (number of frames).
101+
spatial_interpolation_scale (`float`, defaults to 1.0):
102+
Scale factor for spatial grid interpolation.
103+
temporal_interpolation_scale (`float`, defaults to 1.0):
104+
Scale factor for temporal grid interpolation.
105+
106+
Returns:
107+
`torch.Tensor`:
108+
The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
109+
embed_dim]`.
110+
"""
111+
if output_type == "np":
112+
return _get_3d_sincos_pos_embed_np(
113+
embed_dim=embed_dim,
114+
spatial_size=spatial_size,
115+
temporal_size=temporal_size,
116+
spatial_interpolation_scale=spatial_interpolation_scale,
117+
temporal_interpolation_scale=temporal_interpolation_scale,
118+
)
119+
if embed_dim % 4 != 0:
120+
raise ValueError("`embed_dim` must be divisible by 4")
121+
if isinstance(spatial_size, int):
122+
spatial_size = (spatial_size, spatial_size)
123+
124+
embed_dim_spatial = 3 * embed_dim // 4
125+
embed_dim_temporal = embed_dim // 4
126+
127+
# 1. Spatial
128+
grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
129+
grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
130+
grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
131+
grid = torch.stack(grid, dim=0)
132+
133+
grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
134+
pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt")
135+
136+
# 2. Temporal
137+
grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
138+
pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t, output_type="pt")
139+
140+
# 3. Concat
141+
pos_embed_spatial = pos_embed_spatial[None, :, :]
142+
pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0) # [T, H*W, D // 4 * 3]
143+
144+
pos_embed_temporal = pos_embed_temporal[:, None, :]
145+
pos_embed_temporal = pos_embed_temporal.repeat_interleave(
146+
spatial_size[0] * spatial_size[1], dim=1
147+
) # [T, H*W, D // 4]
148+
149+
pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1) # [T, H*W, D]
150+
return pos_embed
151+
152+
153+
def _get_3d_sincos_pos_embed_np(
154+
embed_dim: int,
155+
spatial_size: Union[int, Tuple[int, int]],
156+
temporal_size: int,
157+
spatial_interpolation_scale: float = 1.0,
158+
temporal_interpolation_scale: float = 1.0,
87159
) -> np.ndarray:
88160
r"""
89161
Creates 3D sinusoidal positional embeddings.
@@ -106,6 +178,12 @@ def get_3d_sincos_pos_embed(
106178
The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
107179
embed_dim]`.
108180
"""
181+
deprecation_message = (
182+
"`get_3d_sincos_pos_embed` uses `torch` and supports `device`."
183+
" `from_numpy` is no longer required."
184+
" Pass `output_type='pt' to use the new version now."
185+
)
186+
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
109187
if embed_dim % 4 != 0:
110188
raise ValueError("`embed_dim` must be divisible by 4")
111189
if isinstance(spatial_size, int):
@@ -139,6 +217,143 @@ def get_3d_sincos_pos_embed(
139217

140218

141219
def get_2d_sincos_pos_embed(
220+
embed_dim,
221+
grid_size,
222+
cls_token=False,
223+
extra_tokens=0,
224+
interpolation_scale=1.0,
225+
base_size=16,
226+
device: Optional[torch.device] = None,
227+
output_type: str = "np",
228+
):
229+
"""
230+
Creates 2D sinusoidal positional embeddings.
231+
232+
Args:
233+
embed_dim (`int`):
234+
The embedding dimension.
235+
grid_size (`int`):
236+
The size of the grid height and width.
237+
cls_token (`bool`, defaults to `False`):
238+
Whether or not to add a classification token.
239+
extra_tokens (`int`, defaults to `0`):
240+
The number of extra tokens to add.
241+
interpolation_scale (`float`, defaults to `1.0`):
242+
The scale of the interpolation.
243+
244+
Returns:
245+
pos_embed (`torch.Tensor`):
246+
Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
247+
embed_dim]` if using cls_token
248+
"""
249+
if output_type == "np":
250+
deprecation_message = (
251+
"`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
252+
" `from_numpy` is no longer required."
253+
" Pass `output_type='pt' to use the new version now."
254+
)
255+
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
256+
return get_2d_sincos_pos_embed_np(
257+
embed_dim=embed_dim,
258+
grid_size=grid_size,
259+
cls_token=cls_token,
260+
extra_tokens=extra_tokens,
261+
interpolation_scale=interpolation_scale,
262+
base_size=base_size,
263+
)
264+
if isinstance(grid_size, int):
265+
grid_size = (grid_size, grid_size)
266+
267+
grid_h = (
268+
torch.arange(grid_size[0], device=device, dtype=torch.float32)
269+
/ (grid_size[0] / base_size)
270+
/ interpolation_scale
271+
)
272+
grid_w = (
273+
torch.arange(grid_size[1], device=device, dtype=torch.float32)
274+
/ (grid_size[1] / base_size)
275+
/ interpolation_scale
276+
)
277+
grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
278+
grid = torch.stack(grid, dim=0)
279+
280+
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
281+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type)
282+
if cls_token and extra_tokens > 0:
283+
pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
284+
return pos_embed
285+
286+
287+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
288+
r"""
289+
This function generates 2D sinusoidal positional embeddings from a grid.
290+
291+
Args:
292+
embed_dim (`int`): The embedding dimension.
293+
grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
294+
295+
Returns:
296+
`torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
297+
"""
298+
if output_type == "np":
299+
deprecation_message = (
300+
"`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
301+
" `from_numpy` is no longer required."
302+
" Pass `output_type='pt' to use the new version now."
303+
)
304+
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
305+
return get_2d_sincos_pos_embed_from_grid_np(
306+
embed_dim=embed_dim,
307+
grid=grid,
308+
)
309+
if embed_dim % 2 != 0:
310+
raise ValueError("embed_dim must be divisible by 2")
311+
312+
# use half of dimensions to encode grid_h
313+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type) # (H*W, D/2)
314+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type) # (H*W, D/2)
315+
316+
emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D)
317+
return emb
318+
319+
320+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
321+
"""
322+
This function generates 1D positional embeddings from a grid.
323+
324+
Args:
325+
embed_dim (`int`): The embedding dimension `D`
326+
pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
327+
328+
Returns:
329+
`torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
330+
"""
331+
if output_type == "np":
332+
deprecation_message = (
333+
"`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
334+
" `from_numpy` is no longer required."
335+
" Pass `output_type='pt' to use the new version now."
336+
)
337+
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
338+
return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
339+
if embed_dim % 2 != 0:
340+
raise ValueError("embed_dim must be divisible by 2")
341+
342+
omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
343+
omega /= embed_dim / 2.0
344+
omega = 1.0 / 10000**omega # (D/2,)
345+
346+
pos = pos.reshape(-1) # (M,)
347+
out = torch.outer(pos, omega) # (M, D/2), outer product
348+
349+
emb_sin = torch.sin(out) # (M, D/2)
350+
emb_cos = torch.cos(out) # (M, D/2)
351+
352+
emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
353+
return emb
354+
355+
356+
def get_2d_sincos_pos_embed_np(
142357
embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
143358
):
144359
"""
@@ -170,13 +385,13 @@ def get_2d_sincos_pos_embed(
170385
grid = np.stack(grid, axis=0)
171386

172387
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
173-
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
388+
pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid)
174389
if cls_token and extra_tokens > 0:
175390
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
176391
return pos_embed
177392

178393

179-
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
394+
def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid):
180395
r"""
181396
This function generates 2D sinusoidal positional embeddings from a grid.
182397
@@ -191,14 +406,14 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
191406
raise ValueError("embed_dim must be divisible by 2")
192407

193408
# use half of dimensions to encode grid_h
194-
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
195-
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
409+
emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0]) # (H*W, D/2)
410+
emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1]) # (H*W, D/2)
196411

197412
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
198413
return emb
199414

200415

201-
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
416+
def get_1d_sincos_pos_embed_from_grid_np(embed_dim, pos):
202417
"""
203418
This function generates 1D positional embeddings from a grid.
204419
@@ -288,10 +503,14 @@ def __init__(
288503
self.pos_embed = None
289504
elif pos_embed_type == "sincos":
290505
pos_embed = get_2d_sincos_pos_embed(
291-
embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
506+
embed_dim,
507+
grid_size,
508+
base_size=self.base_size,
509+
interpolation_scale=self.interpolation_scale,
510+
output_type="pt",
292511
)
293512
persistent = True if pos_embed_max_size else False
294-
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
513+
self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
295514
else:
296515
raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
297516

@@ -341,8 +560,10 @@ def forward(self, latent):
341560
grid_size=(height, width),
342561
base_size=self.base_size,
343562
interpolation_scale=self.interpolation_scale,
563+
device=latent.device,
564+
output_type="pt",
344565
)
345-
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
566+
pos_embed = pos_embed.float().unsqueeze(0)
346567
else:
347568
pos_embed = self.pos_embed
348569

@@ -453,7 +674,9 @@ def __init__(
453674
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
454675
self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
455676

456-
def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
677+
def _get_positional_embeddings(
678+
self, sample_height: int, sample_width: int, sample_frames: int, device: Optional[torch.device] = None
679+
) -> torch.Tensor:
457680
post_patch_height = sample_height // self.patch_size
458681
post_patch_width = sample_width // self.patch_size
459682
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
@@ -465,8 +688,10 @@ def _get_positional_embeddings(self, sample_height: int, sample_width: int, samp
465688
post_time_compression_frames,
466689
self.spatial_interpolation_scale,
467690
self.temporal_interpolation_scale,
691+
device=device,
692+
output_type="pt",
468693
)
469-
pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
694+
pos_embedding = pos_embedding.flatten(0, 1)
470695
joint_pos_embedding = torch.zeros(
471696
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
472697
)
@@ -521,8 +746,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
521746
or self.sample_width != width
522747
or self.sample_frames != pre_time_compression_frames
523748
):
524-
pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
525-
pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
749+
pos_embedding = self._get_positional_embeddings(
750+
height, width, pre_time_compression_frames, device=embeds.device
751+
)
752+
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
526753
else:
527754
pos_embedding = self.pos_embedding
528755

@@ -552,9 +779,11 @@ def __init__(
552779
# Linear projection for text embeddings
553780
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
554781

555-
pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
782+
pos_embed = get_2d_sincos_pos_embed(
783+
hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
784+
)
556785
pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
557-
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
786+
self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
558787

559788
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
560789
batch_size, channel, height, width = hidden_states.shape

src/diffusers/models/transformers/latte_transformer_3d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,9 @@ def __init__(
156156

157157
# define temporal positional embedding
158158
temp_pos_embed = get_1d_sincos_pos_embed_from_grid(
159-
inner_dim, torch.arange(0, video_length).unsqueeze(1)
159+
inner_dim, torch.arange(0, video_length).unsqueeze(1), output_type="pt"
160160
) # 1152 hidden size
161-
self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
161+
self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False)
162162

163163
self.gradient_checkpointing = False
164164

src/diffusers/models/unets/unet_3d_blocks.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1375,6 +1375,7 @@ def forward(
13751375
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
13761376
temb: Optional[torch.Tensor] = None,
13771377
image_only_indicator: Optional[torch.Tensor] = None,
1378+
upsample_size: Optional[int] = None,
13781379
) -> torch.Tensor:
13791380
for resnet in self.resnets:
13801381
# pop res hidden states
@@ -1415,7 +1416,7 @@ def custom_forward(*inputs):
14151416

14161417
if self.upsamplers is not None:
14171418
for upsampler in self.upsamplers:
1418-
hidden_states = upsampler(hidden_states)
1419+
hidden_states = upsampler(hidden_states, upsample_size)
14191420

14201421
return hidden_states
14211422

@@ -1485,6 +1486,7 @@ def forward(
14851486
temb: Optional[torch.Tensor] = None,
14861487
encoder_hidden_states: Optional[torch.Tensor] = None,
14871488
image_only_indicator: Optional[torch.Tensor] = None,
1489+
upsample_size: Optional[int] = None,
14881490
) -> torch.Tensor:
14891491
for resnet, attn in zip(self.resnets, self.attentions):
14901492
# pop res hidden states
@@ -1533,6 +1535,6 @@ def custom_forward(*inputs):
15331535

15341536
if self.upsamplers is not None:
15351537
for upsampler in self.upsamplers:
1536-
hidden_states = upsampler(hidden_states)
1538+
hidden_states = upsampler(hidden_states, upsample_size)
15371539

15381540
return hidden_states

0 commit comments

Comments
 (0)