diff --git a/packages/ltx-pipelines/src/ltx_pipelines/__init__.py b/packages/ltx-pipelines/src/ltx_pipelines/__init__.py
index a4fd05e..a3dfbac 100644
--- a/packages/ltx-pipelines/src/ltx_pipelines/__init__.py
+++ b/packages/ltx-pipelines/src/ltx_pipelines/__init__.py
@@ -16,11 +16,16 @@
 from ltx_pipelines.keyframe_interpolation import KeyframeInterpolationPipeline
 from ltx_pipelines.ti2vid_one_stage import TI2VidOneStagePipeline
 from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline
-
+from ltx_pipelines.ti2vid_two_stages_deterministic import (
+    TI2VidTwoStagesPipelineDeterministic,
+    enable_deterministic_mode,
+)
 __all__ = [
     "DistilledPipeline",
     "ICLoraPipeline",
     "KeyframeInterpolationPipeline",
     "TI2VidOneStagePipeline",
     "TI2VidTwoStagesPipeline",
+    "TI2VidTwoStagesPipelineDeterministic",
+    "enable_deterministic_mode",
 ]
diff --git a/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages_deterministic.py b/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages_deterministic.py
new file mode 100644
index 0000000..d1a6ad8
--- /dev/null
+++ b/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages_deterministic.py
@@ -0,0 +1,304 @@
+import logging
+import os
+from collections.abc import Iterator
+
+import torch
+
+from ltx_core.components.diffusion_steps import EulerDiffusionStep
+from ltx_core.components.guiders import CFGGuider
+from ltx_core.components.noisers import GaussianNoiser
+from ltx_core.components.protocols import DiffusionStepProtocol
+from ltx_core.components.schedulers import LTX2Scheduler
+from ltx_core.loader import LoraPathStrengthAndSDOps
+from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
+from ltx_core.model.upsampler import upsample_video
+from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
+from ltx_core.model.video_vae import decode_video as vae_decode_video
+from ltx_core.text_encoders.gemma import encode_text
+from ltx_core.types import LatentState, VideoPixelShape
+from ltx_pipelines.utils import ModelLedger
+from ltx_pipelines.utils.args import default_2_stage_arg_parser
+from ltx_pipelines.utils.constants import (
+    AUDIO_SAMPLE_RATE,
+    STAGE_2_DISTILLED_SIGMA_VALUES,
+)
+from ltx_pipelines.utils.helpers import (
+    assert_resolution,
+    cleanup_memory,
+    denoise_audio_video,
+    euler_denoising_loop,
+    generate_enhanced_prompt,
+    get_device,
+    guider_denoising_func,
+    image_conditionings_by_replacing_latent,
+    simple_denoising_func,
+)
+from ltx_pipelines.utils.media_io import encode_video
+from ltx_pipelines.utils.types import PipelineComponents
+
+device = get_device()
+
+
+def enable_deterministic_mode(seed: int) -> None:
+    """
+    Enable deterministic mode for reproducible inference.
+
+    Seeds the PyTorch RNGs and forces deterministic kernel selection in
+    cuDNN and cuBLAS. Call this before any CUDA work (including model
+    loading); expect some performance cost.
+    """
+    # cuBLAS only honors this setting if it is in place before its workspace
+    # is created, so export it before the first CUDA matmul of the process.
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+
+    # Seed the CPU and all CUDA RNGs (this does not touch Python's `random`).
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+
+    # Pin cuDNN to deterministic kernels and disable the autotuner, which can
+    # select different algorithms from run to run.
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+    # Prefer deterministic implementations everywhere; warn_only=True logs a
+    # warning instead of raising for ops that have none.
+    torch.use_deterministic_algorithms(True, warn_only=True)
+
+
+class TI2VidTwoStagesPipelineDeterministic:
+    """
+    Deterministic two-stage text/image-to-video generation pipeline.
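+
+    For reproducible outputs, enable deterministic mode before constructing
+    the pipeline, exactly as ``main()`` below does::
+
+        enable_deterministic_mode(seed)
+        pipeline = TI2VidTwoStagesPipelineDeterministic(...)
+        video, audio = pipeline(prompt=..., seed=seed, ...)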
+
+    Stage 1 generates video at half the target resolution with CFG guidance,
+    then Stage 2 upsamples the latent by 2x and refines it with a distilled
+    LoRA, producing output at the requested resolution. Supports optional
+    image conditioning via the ``images`` parameter.
+    """
+
+    def __init__(
+        self,
+        checkpoint_path: str,
+        distilled_lora: list[LoraPathStrengthAndSDOps],
+        spatial_upsampler_path: str,
+        gemma_root: str,
+        loras: list[LoraPathStrengthAndSDOps],
+        device: str = device,
+        fp8transformer: bool = False,
+    ):
+        self.device = device
+        self.dtype = torch.bfloat16
+        self.stage_1_model_ledger = ModelLedger(
+            dtype=self.dtype,
+            device=device,
+            checkpoint_path=checkpoint_path,
+            gemma_root_path=gemma_root,
+            spatial_upsampler_path=spatial_upsampler_path,
+            loras=loras,
+            fp8transformer=fp8transformer,
+        )
+
+        self.stage_2_model_ledger = self.stage_1_model_ledger.with_loras(
+            loras=distilled_lora,
+        )
+
+        self.pipeline_components = PipelineComponents(
+            dtype=self.dtype,
+            device=device,
+        )
+
+    @torch.inference_mode()
+    def __call__(  # noqa: PLR0913
+        self,
+        prompt: str,
+        negative_prompt: str,
+        seed: int,
+        height: int,
+        width: int,
+        num_frames: int,
+        frame_rate: float,
+        num_inference_steps: int,
+        cfg_guidance_scale: float,
+        images: list[tuple[str, int, float]],
+        tiling_config: TilingConfig | None = None,
+        enhance_prompt: bool = False,
+    ) -> tuple[Iterator[torch.Tensor], torch.Tensor]:
+        assert_resolution(height=height, width=width, is_two_stage=True)
+
+        generator = torch.Generator(device=self.device).manual_seed(seed)
+        noiser = GaussianNoiser(generator=generator)
+        stepper = EulerDiffusionStep()
+        cfg_guider = CFGGuider(cfg_guidance_scale)
+        dtype = torch.bfloat16
+
+        text_encoder = self.stage_1_model_ledger.text_encoder()
+        if enhance_prompt:
+            prompt = generate_enhanced_prompt(
+                text_encoder, prompt, images[0][0] if images else None, seed=seed
+            )
+        context_p, context_n = encode_text(text_encoder, prompts=[prompt, negative_prompt])
+        v_context_p, a_context_p = context_p
+        v_context_n, a_context_n = context_n
+
+        torch.cuda.synchronize()
+        del text_encoder
+        cleanup_memory()
+
+        # Stage 1: initial video generation at half the target resolution.
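+        # Determinism note: all sampled noise is drawn through `generator`,
+        # seeded above from `seed`, so the stochastic inputs are reproducible
+        # even if other code consumes the global RNG. CFG is only paid for in
+        # this stage (the guider evaluates both the positive and the negative
+        # context); stage 2 below uses a single-pass distilled denoiser.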
+        video_encoder = self.stage_1_model_ledger.video_encoder()
+        transformer = self.stage_1_model_ledger.transformer()
+        sigmas = LTX2Scheduler().execute(steps=num_inference_steps).to(dtype=torch.float32, device=self.device)
+
+        def first_stage_denoising_loop(
+            sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
+        ) -> tuple[LatentState, LatentState]:
+            return euler_denoising_loop(
+                sigmas=sigmas,
+                video_state=video_state,
+                audio_state=audio_state,
+                stepper=stepper,
+                denoise_fn=guider_denoising_func(
+                    cfg_guider,
+                    v_context_p,
+                    v_context_n,
+                    a_context_p,
+                    a_context_n,
+                    transformer=transformer,  # noqa: F821
+                ),
+            )
+
+        stage_1_output_shape = VideoPixelShape(
+            batch=1,
+            frames=num_frames,
+            width=width // 2,
+            height=height // 2,
+            fps=frame_rate,
+        )
+        stage_1_conditionings = image_conditionings_by_replacing_latent(
+            images=images,
+            height=stage_1_output_shape.height,
+            width=stage_1_output_shape.width,
+            video_encoder=video_encoder,
+            dtype=dtype,
+            device=self.device,
+        )
+        video_state, audio_state = denoise_audio_video(
+            output_shape=stage_1_output_shape,
+            conditionings=stage_1_conditionings,
+            noiser=noiser,
+            sigmas=sigmas,
+            stepper=stepper,
+            denoising_loop_fn=first_stage_denoising_loop,
+            components=self.pipeline_components,
+            dtype=dtype,
+            device=self.device,
+        )
+
+        torch.cuda.synchronize()
+        del transformer
+        cleanup_memory()
+
+        # Stage 2: upsample the latent 2x and refine at the target resolution
+        # with the distilled LoRA.
+        upscaled_video_latent = upsample_video(
+            latent=video_state.latent[:1],
+            video_encoder=video_encoder,
+            upsampler=self.stage_2_model_ledger.spatial_upsampler(),
+        )
+
+        torch.cuda.synchronize()
+        cleanup_memory()
+
+        transformer = self.stage_2_model_ledger.transformer()
+        distilled_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=self.device)
+
+        def second_stage_denoising_loop(
+            sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
+        ) -> tuple[LatentState, LatentState]:
+            return euler_denoising_loop(
+                sigmas=sigmas,
+                video_state=video_state,
+                audio_state=audio_state,
+                stepper=stepper,
+                denoise_fn=simple_denoising_func(
+                    video_context=v_context_p,
+                    audio_context=a_context_p,
+                    transformer=transformer,  # noqa: F821
+                ),
+            )
+
+        stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
+        stage_2_conditionings = image_conditionings_by_replacing_latent(
+            images=images,
+            height=stage_2_output_shape.height,
+            width=stage_2_output_shape.width,
+            video_encoder=video_encoder,
+            dtype=dtype,
+            device=self.device,
+        )
+        video_state, audio_state = denoise_audio_video(
+            output_shape=stage_2_output_shape,
+            conditionings=stage_2_conditionings,
+            noiser=noiser,
+            sigmas=distilled_sigmas,
+            stepper=stepper,
+            denoising_loop_fn=second_stage_denoising_loop,
+            components=self.pipeline_components,
+            dtype=dtype,
+            device=self.device,
+            noise_scale=distilled_sigmas[0],
+            initial_video_latent=upscaled_video_latent,
+            initial_audio_latent=audio_state.latent,
+        )
+
+        torch.cuda.synchronize()
+        del transformer
+        del video_encoder
+        cleanup_memory()
+
+        decoded_video = vae_decode_video(video_state.latent, self.stage_2_model_ledger.video_decoder(), tiling_config)
+        decoded_audio = vae_decode_audio(
+            audio_state.latent, self.stage_2_model_ledger.audio_decoder(), self.stage_2_model_ledger.vocoder()
+        )
+
+        return decoded_video, decoded_audio
+
+
+@torch.inference_mode()
+def main() -> None:
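+    """CLI entry point: enable deterministic mode, run the two-stage
+    pipeline, and encode the result to ``args.output_path``."""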
+    logging.getLogger().setLevel(logging.INFO)
+    parser = default_2_stage_arg_parser()
+    args = parser.parse_args()
+
+    # Enable deterministic mode before any model loading or CUDA work so the
+    # seeds and the CUBLAS workspace setting apply to the whole run.
+    enable_deterministic_mode(args.seed)
+
+    pipeline = TI2VidTwoStagesPipelineDeterministic(
+        checkpoint_path=args.checkpoint_path,
+        distilled_lora=args.distilled_lora,
+        spatial_upsampler_path=args.spatial_upsampler_path,
+        gemma_root=args.gemma_root,
+        loras=args.lora,
+        fp8transformer=args.enable_fp8,
+    )
+    tiling_config = TilingConfig.default()
+    video_chunks_number = get_video_chunks_number(args.num_frames, tiling_config)
+    video, audio = pipeline(
+        prompt=args.prompt,
+        negative_prompt=args.negative_prompt,
+        seed=args.seed,
+        height=args.height,
+        width=args.width,
+        num_frames=args.num_frames,
+        frame_rate=args.frame_rate,
+        num_inference_steps=args.num_inference_steps,
+        cfg_guidance_scale=args.cfg_guidance_scale,
+        images=args.images,
+        tiling_config=tiling_config,
+    )
+
+    encode_video(
+        video=video,
+        fps=args.frame_rate,
+        audio=audio,
+        audio_sample_rate=AUDIO_SAMPLE_RATE,
+        output_path=args.output_path,
+        video_chunks_number=video_chunks_number,
+    )
+
+
+if __name__ == "__main__":
+    main()
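+
+# Checking reproducibility by hand: enable_deterministic_mode() passes
+# warn_only=True, so an op without a deterministic implementation only logs a
+# warning instead of failing. The reliable test is therefore to run this
+# script twice with identical arguments (same seed included) and compare the
+# two output files, e.g. with sha256sum (file names here are illustrative):
+#
+#   sha256sum run_a.mp4 run_b.mp4
+#
+# Matching hashes confirm end-to-end determinism on that machine; differing
+# hashes usually trace back to an op flagged by use_deterministic_algorithms.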