
Commit 30a3bb7

pack/unpack latents
1 parent c201880 commit 30a3bb7

2 files changed: +55 −42 lines changed

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 11 additions & 27 deletions
@@ -116,17 +116,14 @@ def __init__(
         self.theta = theta
 
     def forward(
-        self, hidden_states: torch.Tensor, rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None
+        self, hidden_states: torch.Tensor, num_frames: int, height: int, width: int, rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        batch_size, num_channels, num_frames, height, width = hidden_states.shape
-        post_patch_num_frames = num_frames // self.patch_size_t
-        post_patch_height = height // self.patch_size
-        post_patch_width = width // self.patch_size
+        batch_size = hidden_states.size(0)
 
         # Always compute rope in fp32
-        grid_h = torch.arange(post_patch_height, dtype=torch.float32, device=hidden_states.device)
-        grid_w = torch.arange(post_patch_width, dtype=torch.float32, device=hidden_states.device)
-        grid_f = torch.arange(post_patch_num_frames, dtype=torch.float32, device=hidden_states.device)
+        grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
+        grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
+        grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
         grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
         grid = torch.stack(grid, dim=0)
         grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
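Because the latents now reach the rotary embedding already packed into a flat sequence, the frame/height/width grid sizes can no longer be read off a 5D tensor shape and are passed in explicitly by the caller. A minimal standalone sketch of the fp32 grid construction, using toy dimensions that are not taken from any model config:

import torch

# Toy post-patch latent dimensions (placeholders, not LTX defaults).
num_frames, height, width = 3, 4, 6
grid_h = torch.arange(height, dtype=torch.float32)
grid_w = torch.arange(width, dtype=torch.float32)
grid_f = torch.arange(num_frames, dtype=torch.float32)

# Same meshgrid/stack as in the diff: one (f, h, w) coordinate triple per latent position.
grid = torch.stack(torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij"), dim=0)
print(grid.shape)  # torch.Size([3, 3, 4, 6]) -> (coordinate axis, frames, height, width)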
@@ -374,28 +371,20 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         timestep: torch.LongTensor,
         encoder_attention_mask: torch.Tensor,
+        num_frames: int,
+        height: int,
+        width: int,
         rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
         return_dict: bool = True,
     ) -> torch.Tensor:
-        image_rotary_emb = self.rope(hidden_states, rope_interpolation_scale)
+        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
 
         # convert encoder_attention_mask to a bias the same way we do for attention_mask
         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
             encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
             encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
 
-        batch_size, num_channels, num_frames, height, width = hidden_states.shape
-        p = self.config.patch_size
-        p_t = self.config.patch_size_t
-
-        post_patch_height = height // p
-        post_patch_width = width // p
-        post_patch_num_frames = num_frames // p_t
-
-        hidden_states = hidden_states.reshape(
-            batch_size, -1, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p
-        )
-        hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
+        batch_size = hidden_states.size(0)
         hidden_states = self.proj_in(hidden_states)
 
         temb, embedded_timestep = self.time_embed(
@@ -446,12 +435,7 @@ def custom_forward(*inputs):
 
         hidden_states = self.norm_out(hidden_states)
         hidden_states = hidden_states * (1 + scale) + shift
-        hidden_states = self.proj_out(hidden_states)
-
-        hidden_states = hidden_states.reshape(
-            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
-        )
-        output = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+        output = self.proj_out(hidden_states)
 
         if not return_dict:
             return (output,)
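With the patchify/unpatchify reshapes removed, the transformer consumes and returns the packed (batch, seq_len, channels) layout end to end; the pipeline is now responsible for packing before denoising and unpacking before decoding. A rough shape sketch, where every dimension below is an illustrative placeholder rather than a config value:

import torch

# Illustrative packed input for the transformer; none of these values come from the config.
batch_size, packed_dim = 2, 128
latent_num_frames, latent_height, latent_width = 5, 16, 22
seq_len = latent_num_frames * latent_height * latent_width  # patch sizes of 1 assumed

hidden_states = torch.randn(batch_size, seq_len, packed_dim)
print(hidden_states.shape)  # torch.Size([2, 1760, 128]); proj_in/proj_out keep this layout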

src/diffusers/pipelines/ltx/pipeline_ltx.py

Lines changed: 44 additions & 15 deletions
@@ -387,6 +387,34 @@ def check_inputs(
                     f" {negative_prompt_attention_mask.shape}."
                 )
 
+    @staticmethod
+    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = latents.shape
+        post_patch_num_frames = num_frames // patch_size_t
+        post_patch_height = height // patch_size
+        post_patch_width = width // patch_size
+        latents = latents.reshape(
+            batch_size,
+            -1,
+            post_patch_num_frames,
+            patch_size_t,
+            post_patch_height,
+            patch_size,
+            post_patch_width,
+            patch_size,
+        )
+        latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
+        return latents
+
+    @staticmethod
+    def _unpack_latents(
+        latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
+    ) -> torch.Tensor:
+        batch_size, num_channels, video_sequence_length = latents.shape
+        latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
+        latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
+        return latents
+
     def prepare_latents(
         self,
         batch_size: int = 1,
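The two helpers are exact inverses of each other. A standalone round-trip check that mirrors the reshape/permute sequences above (tensor sizes are arbitrary; patch sizes of 1 match the call sites in this pipeline):

import torch

b, c, f, h, w = 1, 128, 3, 4, 6   # arbitrary toy latent shape
p, p_t = 1, 1                     # patch sizes used at the call sites below

latents = torch.randn(b, c, f, h, w)

# pack: (b, c, f, h, w) -> (b, f*h*w, c) when p = p_t = 1
packed = latents.reshape(b, -1, f // p_t, p_t, h // p, p, w // p, p)
packed = packed.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)

# unpack: (b, seq_len, dim) -> (b, c, f, h, w)
unpacked = packed.reshape(b, f // p_t, h // p, w // p, -1, p_t, p, p)
unpacked = unpacked.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)

assert packed.shape == (b, f * h * w, c)
assert torch.equal(unpacked, latents)  # the reshapes round-trip exactly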
@@ -415,20 +443,9 @@ def prepare_latents(
         )
 
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = self._pack_latents(latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size)
         return latents
 
-    def decode_latents(self, latents: torch.Tensor):
-        # unscale/denormalize the latents
-        latents_mean = self.vae.latents_mean.view(1, self.vae.config.latent_channels, 1, 1, 1).to(
-            latents.device, latents.dtype
-        )
-        latents_std = self.vae.latents_std.view(1, self.vae.config.latent_channels, 1, 1, 1).to(
-            latents.device, latents.dtype
-        )
-        latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
-        video = self.vae.decode(latents, return_dict=False)[0]
-        return video
-
     @property
     def guidance_scale(self):
         return self._guidance_scale
@@ -610,10 +627,10 @@ def __call__(
         )
 
         # 5. Prepare timesteps
-        latent_frames = latents.size(2)
+        latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
         latent_height = height // self.vae_spatial_compression_ratio
         latent_width = width // self.vae_spatial_compression_ratio
-        video_sequence_length = latent_height * latent_width * latent_frames
+        video_sequence_length = latent_num_frames * latent_height * latent_width
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
         mu = calculate_shift(
             video_sequence_length,
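The latent frame count is now derived from the requested num_frames and the VAE's temporal compression rather than read off the latent tensor, which still works once the latents are already packed at this point. A quick worked example with assumed values (161 frames and 8x temporal compression are illustrative, not taken from this diff):

num_frames = 161                      # assumed request
vae_temporal_compression_ratio = 8    # assumed VAE temporal compression
latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1
print(latent_num_frames)              # 21, i.e. (161 - 1) // 8 + 1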
@@ -656,6 +673,9 @@ def __call__(
                     encoder_hidden_states=prompt_embeds,
                     timestep=timestep,
                     encoder_attention_mask=prompt_attention_mask,
+                    num_frames=latent_num_frames,
+                    height=latent_height,
+                    width=latent_width,
                     rope_interpolation_scale=rope_interpolation_scale,
                     return_dict=False,
                 )[0]
@@ -689,7 +709,16 @@ def __call__(
         if output_type == "latent":
             video = latents
         else:
-            video = self.decode_latents(latents)
+            latents = self._unpack_latents(latents, latent_num_frames, latent_height, latent_width, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size)
+            # unscale/denormalize the latents
+            latents_mean = self.vae.latents_mean.view(1, self.vae.config.latent_channels, 1, 1, 1).to(
+                latents.device, latents.dtype
+            )
+            latents_std = self.vae.latents_std.view(1, self.vae.config.latent_channels, 1, 1, 1).to(
+                latents.device, latents.dtype
+            )
+            latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+            video = self.vae.decode(latents, return_dict=False)[0]
             video = self.video_processor.postprocess_video(video, output_type=output_type)
 
         # Offload all models
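The denormalization folded into __call__ here is the inverse of the per-channel normalization applied when the latents were produced; a minimal sketch with placeholder statistics (the real mean, std, and scaling factor live on the VAE, and the encode-side formula below is assumed rather than taken from this diff):

import torch

# Placeholder per-channel statistics and scaling factor, for illustration only.
latent_channels, scaling_factor = 4, 1.0
latents_mean = torch.randn(1, latent_channels, 1, 1, 1)
latents_std = torch.rand(1, latent_channels, 1, 1, 1) + 0.5

x = torch.randn(2, latent_channels, 3, 4, 4)

# encode-side normalization (assumed form): z = (x - mean) * scaling_factor / std
z = (x - latents_mean) * scaling_factor / latents_std
# decode-side denormalization, as in the diff: x = z * std / scaling_factor + mean
recovered = z * latents_std / scaling_factor + latents_mean
assert torch.allclose(recovered, x)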
