7 changes: 6 additions & 1 deletion packages/ltx-pipelines/src/ltx_pipelines/__init__.py
@@ -16,11 +16,16 @@
from ltx_pipelines.keyframe_interpolation import KeyframeInterpolationPipeline
from ltx_pipelines.ti2vid_one_stage import TI2VidOneStagePipeline
from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline
from ltx_pipelines.ti2vid_two_stages_deterministic import (
    TI2VidTwoStagesPipelineDeterministic,
    enable_deterministic_mode,
)

__all__ = [
"DistilledPipeline",
"ICLoraPipeline",
"KeyframeInterpolationPipeline",
"TI2VidOneStagePipeline",
"TI2VidTwoStagesPipeline",
"TI2VidTwoStagesPipelineDeterministic",
"enable_deterministic_mode",
]
304 changes: 304 additions & 0 deletions packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages_deterministic.py
@@ -0,0 +1,304 @@
import logging
import os
from collections.abc import Iterator

import torch

from ltx_core.components.diffusion_steps import EulerDiffusionStep
from ltx_core.components.guiders import CFGGuider
from ltx_core.components.noisers import GaussianNoiser
from ltx_core.components.protocols import DiffusionStepProtocol
from ltx_core.components.schedulers import LTX2Scheduler
from ltx_core.loader import LoraPathStrengthAndSDOps
from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
from ltx_core.model.upsampler import upsample_video
from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
from ltx_core.model.video_vae import decode_video as vae_decode_video
from ltx_core.text_encoders.gemma import encode_text
from ltx_core.types import LatentState, VideoPixelShape
from ltx_pipelines.utils import ModelLedger
from ltx_pipelines.utils.args import default_2_stage_arg_parser
from ltx_pipelines.utils.constants import (
AUDIO_SAMPLE_RATE,
STAGE_2_DISTILLED_SIGMA_VALUES,
)
from ltx_pipelines.utils.helpers import (
assert_resolution,
cleanup_memory,
denoise_audio_video,
euler_denoising_loop,
generate_enhanced_prompt,
get_device,
guider_denoising_func,
image_conditionings_by_replacing_latent,
simple_denoising_func,
)
from ltx_pipelines.utils.media_io import encode_video
from ltx_pipelines.utils.types import PipelineComponents

device = get_device()


def enable_deterministic_mode(seed: int) -> None:
"""
Enable deterministic mode for reproducible inference.
This sets various PyTorch and CUDA settings to ensure reproducibility.
Note: This may impact performance.
"""
# Set Python and PyTorch seeds
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Enable deterministic algorithms
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

    # Prefer deterministic algorithm implementations; with warn_only=True,
    # ops lacking a deterministic implementation emit a warning instead of raising.
    torch.use_deterministic_algorithms(True, warn_only=True)

    # cuBLAS needs this workspace config for deterministic matrix operations;
    # it must be set before the first cuBLAS call.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


class TI2VidTwoStagesPipelineDeterministic:
"""
Two-stage text/image-to-video generation pipeline.
Stage 1 generates video at the target resolution with CFG guidance, then
Stage 2 upsamples by 2x and refines using a distilled LoRA for higher
quality output. Supports optional image conditioning via the images parameter.
"""

def __init__(
self,
checkpoint_path: str,
distilled_lora: list[LoraPathStrengthAndSDOps],
spatial_upsampler_path: str,
gemma_root: str,
loras: list[LoraPathStrengthAndSDOps],
device: str = device,
fp8transformer: bool = False,
):
self.device = device
self.dtype = torch.bfloat16
self.stage_1_model_ledger = ModelLedger(
dtype=self.dtype,
device=device,
checkpoint_path=checkpoint_path,
gemma_root_path=gemma_root,
spatial_upsampler_path=spatial_upsampler_path,
loras=loras,
fp8transformer=fp8transformer,
)

self.stage_2_model_ledger = self.stage_1_model_ledger.with_loras(
loras=distilled_lora,
)

self.pipeline_components = PipelineComponents(
dtype=self.dtype,
device=device,
)

@torch.inference_mode()
def __call__( # noqa: PLR0913
self,
prompt: str,
negative_prompt: str,
seed: int,
height: int,
width: int,
num_frames: int,
frame_rate: float,
num_inference_steps: int,
cfg_guidance_scale: float,
images: list[tuple[str, int, float]],
tiling_config: TilingConfig | None = None,
enhance_prompt: bool = False,
) -> tuple[Iterator[torch.Tensor], torch.Tensor]:
assert_resolution(height=height, width=width, is_two_stage=True)

generator = torch.Generator(device=self.device).manual_seed(seed)
noiser = GaussianNoiser(generator=generator)
stepper = EulerDiffusionStep()
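        # Classifier-free guidance: each denoising step blends the conditioned
        # and unconditioned predictions, scaled by cfg_guidance_scale.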
cfg_guider = CFGGuider(cfg_guidance_scale)
dtype = torch.bfloat16

text_encoder = self.stage_1_model_ledger.text_encoder()
if enhance_prompt:
prompt = generate_enhanced_prompt(
text_encoder, prompt, images[0][0] if len(images) > 0 else None, seed=seed
)
context_p, context_n = encode_text(text_encoder, prompts=[prompt, negative_prompt])
v_context_p, a_context_p = context_p
v_context_n, a_context_n = context_n

torch.cuda.synchronize()
del text_encoder
cleanup_memory()

        # Stage 1: initial generation at half the target resolution (stage 2 upsamples by 2x).
video_encoder = self.stage_1_model_ledger.video_encoder()
transformer = self.stage_1_model_ledger.transformer()
sigmas = LTX2Scheduler().execute(steps=num_inference_steps).to(dtype=torch.float32, device=self.device)

def first_stage_denoising_loop(
sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
) -> tuple[LatentState, LatentState]:
return euler_denoising_loop(
sigmas=sigmas,
video_state=video_state,
audio_state=audio_state,
stepper=stepper,
denoise_fn=guider_denoising_func(
cfg_guider,
v_context_p,
v_context_n,
a_context_p,
a_context_n,
transformer=transformer, # noqa: F821
),
)

stage_1_output_shape = VideoPixelShape(
batch=1,
frames=num_frames,
width=width // 2,
height=height // 2,
fps=frame_rate,
)
stage_1_conditionings = image_conditionings_by_replacing_latent(
images=images,
height=stage_1_output_shape.height,
width=stage_1_output_shape.width,
video_encoder=video_encoder,
dtype=dtype,
device=self.device,
)
video_state, audio_state = denoise_audio_video(
output_shape=stage_1_output_shape,
conditionings=stage_1_conditionings,
noiser=noiser,
sigmas=sigmas,
stepper=stepper,
denoising_loop_fn=first_stage_denoising_loop,
components=self.pipeline_components,
dtype=dtype,
device=self.device,
)

torch.cuda.synchronize()
del transformer
cleanup_memory()

        # Stage 2: upsample and refine the video at the target resolution with the distilled LoRA.
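        # The spatial upsampler doubles the latent's spatial resolution,
        # bringing the stage-1 output up to the requested width and height.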
upscaled_video_latent = upsample_video(
latent=video_state.latent[:1],
video_encoder=video_encoder,
upsampler=self.stage_2_model_ledger.spatial_upsampler(),
)

torch.cuda.synchronize()
cleanup_memory()

transformer = self.stage_2_model_ledger.transformer()
        distilled_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=self.device)

def second_stage_denoising_loop(
sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
) -> tuple[LatentState, LatentState]:
return euler_denoising_loop(
sigmas=sigmas,
video_state=video_state,
audio_state=audio_state,
stepper=stepper,
denoise_fn=simple_denoising_func(
video_context=v_context_p,
audio_context=a_context_p,
transformer=transformer, # noqa: F821
),
)

stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
stage_2_conditionings = image_conditionings_by_replacing_latent(
images=images,
height=stage_2_output_shape.height,
width=stage_2_output_shape.width,
video_encoder=video_encoder,
dtype=dtype,
device=self.device,
)
video_state, audio_state = denoise_audio_video(
output_shape=stage_2_output_shape,
conditionings=stage_2_conditionings,
noiser=noiser,
sigmas=distilled_sigmas,
stepper=stepper,
denoising_loop_fn=second_stage_denoising_loop,
components=self.pipeline_components,
dtype=dtype,
device=self.device,
noise_scale=distilled_sigmas[0],
initial_video_latent=upscaled_video_latent,
initial_audio_latent=audio_state.latent,
)

torch.cuda.synchronize()
del transformer
del video_encoder
cleanup_memory()

decoded_video = vae_decode_video(video_state.latent, self.stage_2_model_ledger.video_decoder(), tiling_config)
decoded_audio = vae_decode_audio(
audio_state.latent, self.stage_2_model_ledger.audio_decoder(), self.stage_2_model_ledger.vocoder()
)

return decoded_video, decoded_audio
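

# Illustrative sketch (not part of the PR): minimal programmatic use of the
# deterministic pipeline. All paths and generation settings below are
# placeholder assumptions; real checkpoint, upsampler, and Gemma paths are
# required for this to run.
def _example_deterministic_run() -> tuple[Iterator[torch.Tensor], torch.Tensor]:
    enable_deterministic_mode(42)
    pipeline = TI2VidTwoStagesPipelineDeterministic(
        checkpoint_path="/path/to/checkpoint.safetensors",
        distilled_lora=[],
        spatial_upsampler_path="/path/to/spatial_upsampler.safetensors",
        gemma_root="/path/to/gemma",
        loras=[],
    )
    # Re-running this call with the same seed should produce identical output.
    return pipeline(
        prompt="A calm ocean at sunset",
        negative_prompt="",
        seed=42,
        height=512,
        width=768,
        num_frames=121,
        frame_rate=24.0,
        num_inference_steps=40,
        cfg_guidance_scale=5.0,
        images=[],
    )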


@torch.inference_mode()
def main() -> None:
logging.getLogger().setLevel(logging.INFO)
parser = default_2_stage_arg_parser()
args = parser.parse_args()

# Enable deterministic mode for reproducible outputs with the same seed
enable_deterministic_mode(args.seed)

pipeline = TI2VidTwoStagesPipelineDeterministic(
checkpoint_path=args.checkpoint_path,
distilled_lora=args.distilled_lora,
spatial_upsampler_path=args.spatial_upsampler_path,
gemma_root=args.gemma_root,
loras=args.lora,
fp8transformer=args.enable_fp8,
)
tiling_config = TilingConfig.default()
video_chunks_number = get_video_chunks_number(args.num_frames, tiling_config)
video, audio = pipeline(
prompt=args.prompt,
negative_prompt=args.negative_prompt,
seed=args.seed,
height=args.height,
width=args.width,
num_frames=args.num_frames,
frame_rate=args.frame_rate,
num_inference_steps=args.num_inference_steps,
cfg_guidance_scale=args.cfg_guidance_scale,
images=args.images,
tiling_config=tiling_config,
)

encode_video(
video=video,
fps=args.frame_rate,
audio=audio,
audio_sample_rate=AUDIO_SAMPLE_RATE,
output_path=args.output_path,
video_chunks_number=video_chunks_number,
)


if __name__ == "__main__":
main()
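
# To list the CLI flags this script accepts (module path inferred from the
# import in ltx_pipelines/__init__.py):
#
#     python -m ltx_pipelines.ti2vid_two_stages_deterministic --help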