
Commit 17a57b2

address review comment: #11428 (comment)
1 parent 2f9efa9 commit 17a57b2

File tree

1 file changed: +45, -46 lines

src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py

Lines changed: 45 additions & 46 deletions
@@ -236,15 +236,42 @@ def forward(
         post_patch_width = width // p
         original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
 
+        if indices_latents is None:
+            indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1)
+
+        hidden_states = self.x_embedder(hidden_states)
+        image_rotary_emb = self.rope(
+            frame_indices=indices_latents, height=height, width=width, device=hidden_states.device
+        )
+
+        latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder(
+            latents_clean, latents_history_2x, latents_history_4x
+        )
+
+        if latents_clean is not None and indices_latents_clean is not None:
+            image_rotary_emb_clean = self.rope(
+                frame_indices=indices_latents_clean, height=height, width=width, device=hidden_states.device
+            )
+        if latents_history_2x is not None and indices_latents_history_2x is not None:
+            image_rotary_emb_history_2x = self.rope(
+                frame_indices=indices_latents_history_2x, height=height, width=width, device=hidden_states.device
+            )
+        if latents_history_4x is not None and indices_latents_history_4x is not None:
+            image_rotary_emb_history_4x = self.rope(
+                frame_indices=indices_latents_history_4x, height=height, width=width, device=hidden_states.device
+            )
+
         hidden_states, image_rotary_emb = self._pack_history_states(
             hidden_states,
-            indices_latents,
             latents_clean,
             latents_history_2x,
             latents_history_4x,
-            indices_latents_clean,
-            indices_latents_history_2x,
-            indices_latents_history_4x,
+            image_rotary_emb,
+            image_rotary_emb_clean,
+            image_rotary_emb_history_2x,
+            image_rotary_emb_history_4x,
+            post_patch_height,
+            post_patch_width,
         )
 
         temb, _ = self.time_text_embed(timestep, pooled_projections, guidance)
@@ -318,76 +345,48 @@ def forward(
     def _pack_history_states(
         self,
         hidden_states: torch.Tensor,
-        indices_latents: torch.Tensor,
         latents_clean: Optional[torch.Tensor] = None,
         latents_history_2x: Optional[torch.Tensor] = None,
         latents_history_4x: Optional[torch.Tensor] = None,
-        indices_latents_clean: Optional[torch.Tensor] = None,
-        indices_latents_history_2x: Optional[torch.Tensor] = None,
-        indices_latents_history_4x: Optional[torch.Tensor] = None,
+        image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] = None,
+        image_rotary_emb_clean: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        image_rotary_emb_history_2x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        image_rotary_emb_history_4x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        height: int = None,
+        width: int = None,
     ):
-        batch_size, num_channels, num_frames, height, width = hidden_states.shape
-        if indices_latents is None:
-            indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1)
-
-        hidden_states = self.x_embedder(hidden_states)
-        image_rotary_emb = self.rope(
-            frame_indices=indices_latents, height=height, width=width, device=hidden_states.device
-        )
         image_rotary_emb = list(image_rotary_emb)  # convert tuple to list for in-place modification
-        pph, ppw = height // self.config.patch_size, width // self.config.patch_size
 
-        latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder(
-            latents_clean, latents_history_2x, latents_history_4x
-        )
-
-        if latents_clean is not None:
+        if latents_clean is not None and image_rotary_emb_clean is not None:
             hidden_states = torch.cat([latents_clean, hidden_states], dim=1)
-
-            image_rotary_emb_clean = self.rope(
-                frame_indices=indices_latents_clean, height=height, width=width, device=latents_clean.device
-            )
             image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0)
             image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0)
 
-        if latents_history_2x is not None and indices_latents_history_2x is not None:
+        if latents_history_2x is not None and image_rotary_emb_history_2x is not None:
             hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1)
-
-            image_rotary_emb_history_2x = self.rope(
-                frame_indices=indices_latents_history_2x, height=height, width=width, device=latents_history_2x.device
-            )
-            image_rotary_emb_history_2x = self._pad_rotary_emb(
-                image_rotary_emb_history_2x, indices_latents_history_2x.size(0), pph, ppw, (2, 2, 2)
-            )
+            image_rotary_emb_history_2x = self._pad_rotary_emb(image_rotary_emb_history_2x, height, width, (2, 2, 2))
            image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0)
            image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0)
 
-        if latents_history_4x is not None and indices_latents_history_4x is not None:
+        if latents_history_4x is not None and image_rotary_emb_history_4x is not None:
             hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1)
-
-            image_rotary_emb_history_4x = self.rope(
-                frame_indices=indices_latents_history_4x, height=height, width=width, device=latents_history_4x.device
-            )
-            image_rotary_emb_history_4x = self._pad_rotary_emb(
-                image_rotary_emb_history_4x, indices_latents_history_4x.size(0), pph, ppw, (4, 4, 4)
-            )
+            image_rotary_emb_history_4x = self._pad_rotary_emb(image_rotary_emb_history_4x, height, width, (4, 4, 4))
             image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0)
             image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0)
 
-        return hidden_states, image_rotary_emb
+        return hidden_states, tuple(image_rotary_emb)
 
     def _pad_rotary_emb(
         self,
         image_rotary_emb: Tuple[torch.Tensor],
-        num_frames: int,
         height: int,
         width: int,
         kernel_size: Tuple[int, int, int],
     ):
         # freqs_cos, freqs_sin have shape [W * H * T, D / 2], where D is attention head dim
         freqs_cos, freqs_sin = image_rotary_emb
-        freqs_cos = freqs_cos.unsqueeze(0).permute(0, 2, 1).unflatten(2, (num_frames, height, width))
-        freqs_sin = freqs_sin.unsqueeze(0).permute(0, 2, 1).unflatten(2, (num_frames, height, width))
+        freqs_cos = freqs_cos.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
+        freqs_sin = freqs_sin.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
         freqs_cos = _pad_for_3d_conv(freqs_cos, kernel_size)
         freqs_sin = _pad_for_3d_conv(freqs_sin, kernel_size)
         freqs_cos = _center_down_sample_3d(freqs_cos, kernel_size)
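Note on the refactor: forward now calls self.rope itself for the clean, 2x, and 4x history latents and passes the precomputed (cos, sin) pairs into _pack_history_states, which only concatenates them in front of the base embeddings. Below is a minimal standalone sketch of that concatenation pattern; the token counts and head dim are made up for illustration, and the real tensors come from self.rope.

    import torch

    # Illustrative shapes only: a rotary embedding here is a (cos, sin) pair of
    # [num_tokens, head_dim // 2] tensors.
    image_rotary_emb = (torch.randn(10, 8), torch.randn(10, 8))      # current latents
    image_rotary_emb_clean = (torch.randn(4, 8), torch.randn(4, 8))  # clean history latents

    rope = list(image_rotary_emb)  # tuple -> list so the entries can be reassigned in place
    rope[0] = torch.cat([image_rotary_emb_clean[0], rope[0]], dim=0)  # prepend history cos terms
    rope[1] = torch.cat([image_rotary_emb_clean[1], rope[1]], dim=0)  # prepend history sin terms
    image_rotary_emb = tuple(rope)  # handed back as a tuple, as the new return statement does

    assert image_rotary_emb[0].shape == (14, 8)

Because tuples are immutable, the new code round-trips through a list and returns tuple(image_rotary_emb), so callers still receive a tuple.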

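The dropped num_frames parameter of _pad_rotary_emb relies on Tensor.unflatten accepting -1 for one target dimension, so the frame count is inferred from the flattened token dimension. A small self-contained check with illustrative sizes (T, H, W, D below are assumptions, not values from the model):

    import torch

    # Illustrative sizes only: T latent frames, an H x W patch grid, attention head dim D.
    T, H, W, D = 3, 4, 6, 16
    freqs_cos = torch.randn(T * H * W, D // 2)  # flattened per-token frequencies

    # Old: .unflatten(2, (num_frames, height, width)) needed num_frames explicitly.
    # New: -1 lets unflatten infer the frame count from the flattened token dimension.
    reshaped = freqs_cos.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, H, W))
    assert reshaped.shape == (1, D // 2, T, H, W)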