Commit 4e8b2a4

address review comments

1 parent: da475ec
2 files changed (+35 additions, -26 deletions)

src/diffusers/pipelines/ltx/pipeline_ltx.py

Lines changed: 17 additions & 15 deletions
@@ -197,10 +197,6 @@ def __init__(
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 128
         )

-        self.default_height = 512
-        self.default_width = 704
-        self.default_frames = 121
-
     # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline._get_t5_prompt_embeds with 256->128
     def _get_t5_prompt_embeds(
         self,
@@ -389,6 +385,10 @@ def check_inputs(

     @staticmethod
     def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+        # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
+        # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
+        # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
+        # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
         batch_size, num_channels, num_frames, height, width = latents.shape
         post_patch_num_frames = num_frames // patch_size_t
         post_patch_height = height // patch_size
@@ -410,7 +410,10 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
     def _unpack_latents(
         latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
     ) -> torch.Tensor:
-        batch_size, num_channels, video_sequence_length = latents.shape
+        # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
+        # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
+        # what happens in the `_pack_latents` method.
+        batch_size = latents.size(0)
         latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
         latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
         return latents
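
For orientation, the new comments describe a plain patchify round trip. The sketch below is a minimal, self-contained illustration of that reshape/permute logic, not the pipeline's code: only the `unpack_latents` body is taken from the hunk above, the `pack_latents` body and the toy shapes are assumptions chosen so that unpacking exactly inverts packing.

import torch


def pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
    # [B, C, F, H, W] -> [B, F//p_t * H//p * W//p, C * p_t * p * p]
    batch_size, num_channels, num_frames, height, width = latents.shape
    post_patch_num_frames = num_frames // patch_size_t
    post_patch_height = height // patch_size
    post_patch_width = width // patch_size
    latents = latents.reshape(
        batch_size, num_channels, post_patch_num_frames, patch_size_t,
        post_patch_height, patch_size, post_patch_width, patch_size,
    )
    # Move the (F', H', W') grid in front of the channel/patch dims, then collapse both groups.
    latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
    return latents


def unpack_latents(latents, num_frames, height, width, patch_size=1, patch_size_t=1):
    # Inverse of pack_latents: [B, S, D] -> [B, C, F, H, W] (body matches the diff above).
    batch_size = latents.size(0)
    latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
    latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
    return latents


x = torch.randn(2, 128, 13, 16, 24)   # [B, C, F, H, W], toy sizes
tokens = pack_latents(x)              # [2, 13 * 16 * 24, 128] when p = p_t = 1
# The grid sizes passed to unpack are the post-patch sizes (F//p_t, H//p, W//p);
# with p = p_t = 1 they equal F, H, W.
assert torch.equal(unpack_latents(tokens, 13, 16, 24), x)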
@@ -419,6 +422,7 @@ def _unpack_latents(
     def _normalize_latents(
         latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
     ) -> torch.Tensor:
+        # Normalize latents across the channel dimension [B, C, F, H, W]
         latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
         latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
         latents = (latents - latents_mean) * scaling_factor / latents_std
@@ -428,6 +432,7 @@ def _normalize_latents(
     def _denormalize_latents(
         latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
    ) -> torch.Tensor:
+        # Denormalize latents across the channel dimension [B, C, F, H, W]
         latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
         latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
         latents = latents * latents_std / scaling_factor + latents_mean
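
The new comments only label these helpers as per-channel operations. As a sanity check, here is a stand-alone sketch of the same affine map and its inverse; the function names and toy shapes are illustrative, not the pipeline's API.

import torch


def normalize_latents(latents, latents_mean, latents_std, scaling_factor=1.0):
    # Per-channel affine map over a [B, C, F, H, W] tensor (stats broadcast along C).
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    return (latents - latents_mean) * scaling_factor / latents_std


def denormalize_latents(latents, latents_mean, latents_std, scaling_factor=1.0):
    # Exact inverse of normalize_latents.
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    return latents * latents_std / scaling_factor + latents_mean


x = torch.randn(1, 128, 13, 16, 24)            # [B, C, F, H, W]
mean, std = torch.randn(128), torch.rand(128) + 0.5
roundtrip = denormalize_latents(normalize_latents(x, mean, std, 0.7), mean, std, 0.7)
assert torch.allclose(roundtrip, x, atol=1e-5)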
@@ -488,9 +493,9 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         negative_prompt: Optional[Union[str, List[str]]] = None,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        num_frames: int = 81,
+        height: int = 512,
+        width: int = 704,
+        num_frames: int = 161,
         frame_rate: int = 25,
         num_inference_steps: int = 50,
         timesteps: List[int] = None,
@@ -515,11 +520,11 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            height (`int`, *optional*, defaults to `self.default_height`):
+            height (`int`, defaults to `512`):
                 The height in pixels of the generated image. This is set to 480 by default for the best results.
-            width (`int`, *optional*, defaults to `self.default_width`):
+            width (`int`, defaults to `704`):
                 The width in pixels of the generated image. This is set to 848 by default for the best results.
-            num_frames (`int`, defaults to `81 `):
+            num_frames (`int`, defaults to `161`):
                 The number of video frames to generate
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -581,10 +586,6 @@ def __call__(
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

-        height = height or self.default_height
-        width = width or self.default_width
-        latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
-
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt=prompt,
@@ -671,6 +672,7 @@ def __call__(
         self._num_timesteps = len(timesteps)

         # 6. Prepare micro-conditions
+        latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
         rope_interpolation_scale = (
             1 / latent_frame_rate,
             self.vae_spatial_compression_ratio,
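
With the per-instance `self.default_height` / `self.default_width` / `self.default_frames` attributes removed, the 512x704, 161-frame defaults now live directly in the `__call__` signature, so callers can simply omit those arguments. A hedged usage sketch follows; the checkpoint id, dtype, and output handling are assumptions for illustration, not part of this commit.

import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

# Assumed checkpoint id; any LTX-Video checkpoint compatible with LTXPipeline should work.
pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")

# height/width/num_frames are omitted on purpose: after this change they default to
# 512 x 704 and 161 frames via the __call__ signature instead of self.default_* attributes.
video = pipe(prompt="A chimpanzee typing on a vintage typewriter, cinematic lighting").frames[0]
export_to_video(video, "ltx_sample.mp4", fps=25)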

src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py

Lines changed: 18 additions & 11 deletions
@@ -353,6 +353,7 @@ def encode_prompt(

         return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask

+    # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.check_inputs
     def check_inputs(
         self,
         prompt,
@@ -409,6 +410,10 @@ def check_inputs(
     @staticmethod
     # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
     def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
+        # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
+        # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
+        # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
+        # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
         batch_size, num_channels, num_frames, height, width = latents.shape
         post_patch_num_frames = num_frames // patch_size_t
         post_patch_height = height // patch_size
@@ -431,7 +436,10 @@ def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int
     def _unpack_latents(
         latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
     ) -> torch.Tensor:
-        batch_size, num_channels, video_sequence_length = latents.shape
+        # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
+        # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
+        # what happens in the `_pack_latents` method.
+        batch_size = latents.size(0)
         latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
         latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
         return latents
@@ -441,6 +449,7 @@ def _unpack_latents(
     def _normalize_latents(
         latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
     ) -> torch.Tensor:
+        # Normalize latents across the channel dimension [B, C, F, H, W]
         latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
         latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
         latents = (latents - latents_mean) * scaling_factor / latents_std
@@ -451,6 +460,7 @@ def _normalize_latents(
     def _denormalize_latents(
         latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
     ) -> torch.Tensor:
+        # Denormalize latents across the channel dimension [B, C, F, H, W]
         latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
         latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
         latents = latents * latents_std / scaling_factor + latents_mean
@@ -543,9 +553,9 @@ def __call__(
         image: PipelineImageInput = None,
         prompt: Union[str, List[str]] = None,
         negative_prompt: Optional[Union[str, List[str]]] = None,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        num_frames: int = 81,
+        height: int = 512,
+        width: int = 704,
+        num_frames: int = 161,
         frame_rate: int = 25,
         num_inference_steps: int = 50,
         timesteps: List[int] = None,
@@ -572,11 +582,11 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            height (`int`, *optional*, defaults to `self.default_height`):
+            height (`int`, defaults to `512`):
                 The height in pixels of the generated image. This is set to 480 by default for the best results.
-            width (`int`, *optional*, defaults to `self.default_width`):
+            width (`int`, defaults to `704`):
                 The width in pixels of the generated image. This is set to 848 by default for the best results.
-            num_frames (`int`, defaults to `81 `):
+            num_frames (`int`, defaults to `161`):
                 The number of video frames to generate
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
@@ -638,10 +648,6 @@ def __call__(
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

-        height = height or self.default_height
-        width = width or self.default_width
-        latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
-
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt=prompt,
@@ -736,6 +742,7 @@ def __call__(
         self._num_timesteps = len(timesteps)

         # 6. Prepare micro-conditions
+        latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
         rope_interpolation_scale = (
             1 / latent_frame_rate,
             self.vae_spatial_compression_ratio,
