diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md
index 4bc22c0f9f6c..96a6028fbc2d 100644
--- a/docs/source/en/api/pipelines/ltx_video.md
+++ b/docs/source/en/api/pipelines/ltx_video.md
@@ -32,6 +32,7 @@ Available models:
 |:-------------:|:-----------------:|
 | [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
 | [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
+| [`LTX Video 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
 
 Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
 
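With the pipeline changes below, `conditions`, `image`, and `video` all become optional, so the condition pipeline can run from a prompt alone. A minimal sketch of text-only usage (the `Lightricks/LTX-Video-0.9.5` repo id and the sampling settings are assumptions, not part of this diff):

```python
# Hedged sketch: text-to-video with LTXConditionPipeline, no `conditions` passed.
# The checkpoint id below is an assumption; substitute the diffusers-format
# repository or local path you actually use.
import torch
from diffusers import LTXConditionPipeline
from diffusers.utils import export_to_video

pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
pipe.to("cuda")

video = pipe(
    prompt="A calm mountain lake at sunrise, mist drifting over the water",
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    width=704,
    height=480,
    num_frames=161,  # must satisfy k * 8 + 1 for the LTX VAE's temporal compression
    num_inference_steps=40,
    generator=torch.Generator("cuda").manual_seed(0),
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```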
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
index ef1fd568397f..dcfdfaf23288 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
@@ -14,7 +14,7 @@
 
 import inspect
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import PIL.Image
 import torch
@@ -75,6 +75,7 @@
 
         >>> # Generate video
         >>> generator = torch.Generator("cuda").manual_seed(0)
+        >>> # Text-only conditioning is also supported without the need to pass `conditions`
         >>> video = pipe(
         ...     conditions=[condition1, condition2],
         ...     prompt=prompt,
@@ -223,7 +224,7 @@ def retrieve_latents(
 
 class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
     r"""
-    Pipeline for image-to-video generation.
+    Pipeline for text/image/video-to-video generation.
 
     Reference: https://github.com/Lightricks/LTX-Video
 
@@ -482,9 +483,6 @@ def check_inputs(
         if conditions is not None and (image is not None or video is not None):
             raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.")
 
-        if conditions is None and (image is None and video is None):
-            raise ValueError("If `conditions` is not provided, `image` or `video` must be provided.")
-
         if conditions is None:
             if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index):
                 raise ValueError(
@@ -642,9 +640,9 @@ def add_noise_to_image_conditioning_latents(
 
     def prepare_latents(
         self,
-        conditions: List[torch.Tensor],
-        condition_strength: List[float],
-        condition_frame_index: List[int],
+        conditions: Optional[List[torch.Tensor]] = None,
+        condition_strength: Optional[List[float]] = None,
+        condition_frame_index: Optional[List[int]] = None,
         batch_size: int = 1,
         num_channels_latents: int = 128,
         height: int = 512,
@@ -654,7 +652,7 @@ def prepare_latents(
         generator: Optional[torch.Generator] = None,
        device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
-    ) -> None:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
        num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
        latent_height = height // self.vae_spatial_compression_ratio
        latent_width = width // self.vae_spatial_compression_ratio
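The context lines above show how the latent grid is derived from the pixel-space inputs. A quick worked example with the LTX VAE's compression ratios (temporal 8, spatial 32; the concrete frame and resolution numbers are illustrative):

```python
# Worked example of the latent-shape arithmetic from prepare_latents above,
# assuming the LTX VAE's compression ratios (temporal 8, spatial 32).
num_frames, height, width = 161, 480, 704
vae_temporal_compression_ratio = 8
vae_spatial_compression_ratio = 32

num_latent_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1
latent_height = height // vae_spatial_compression_ratio
latent_width = width // vae_spatial_compression_ratio
print(num_latent_frames, latent_height, latent_width)  # 21 15 22
```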
@@ -662,77 +660,80 @@ def prepare_latents(
 
         shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
-        condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32)
-
-        extra_conditioning_latents = []
-        extra_conditioning_video_ids = []
-        extra_conditioning_mask = []
-        extra_conditioning_num_latents = 0
-        for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
-            condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
-            condition_latents = self._normalize_latents(
-                condition_latents, self.vae.latents_mean, self.vae.latents_std
-            ).to(device, dtype=dtype)
-
-            num_data_frames = data.size(2)
-            num_cond_frames = condition_latents.size(2)
-
-            if frame_index == 0:
-                latents[:, :, :num_cond_frames] = torch.lerp(
-                    latents[:, :, :num_cond_frames], condition_latents, strength
-                )
-                condition_latent_frames_mask[:, :num_cond_frames] = strength
+        if len(conditions) > 0:
+            condition_latent_frames_mask = torch.zeros(
+                (batch_size, num_latent_frames), device=device, dtype=torch.float32
+            )
 
-            else:
-                if num_data_frames > 1:
-                    if num_cond_frames < num_prefix_latent_frames:
-                        raise ValueError(
-                            f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
-                        )
-
-                if num_cond_frames > num_prefix_latent_frames:
-                    start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
-                    end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
-                    latents[:, :, start_frame:end_frame] = torch.lerp(
-                        latents[:, :, start_frame:end_frame],
-                        condition_latents[:, :, num_prefix_latent_frames:],
-                        strength,
-                    )
-                    condition_latent_frames_mask[:, start_frame:end_frame] = strength
-                    condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
-
-            noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
-            condition_latents = torch.lerp(noise, condition_latents, strength)
-
-            condition_video_ids = self._prepare_video_ids(
-                batch_size,
-                condition_latents.size(2),
-                latent_height,
-                latent_width,
-                patch_size=self.transformer_spatial_patch_size,
-                patch_size_t=self.transformer_temporal_patch_size,
-                device=device,
-            )
-            condition_video_ids = self._scale_video_ids(
-                condition_video_ids,
-                scale_factor=self.vae_spatial_compression_ratio,
-                scale_factor_t=self.vae_temporal_compression_ratio,
-                frame_index=frame_index,
-                device=device,
-            )
-            condition_latents = self._pack_latents(
-                condition_latents,
-                self.transformer_spatial_patch_size,
-                self.transformer_temporal_patch_size,
-            )
-            condition_conditioning_mask = torch.full(
-                condition_latents.shape[:2], strength, device=device, dtype=dtype
-            )
+            extra_conditioning_latents = []
+            extra_conditioning_video_ids = []
+            extra_conditioning_mask = []
+            extra_conditioning_num_latents = 0
+            for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
+                condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
+                condition_latents = self._normalize_latents(
+                    condition_latents, self.vae.latents_mean, self.vae.latents_std
+                ).to(device, dtype=dtype)
+
+                num_data_frames = data.size(2)
+                num_cond_frames = condition_latents.size(2)
+
+                if frame_index == 0:
+                    latents[:, :, :num_cond_frames] = torch.lerp(
+                        latents[:, :, :num_cond_frames], condition_latents, strength
+                    )
+                    condition_latent_frames_mask[:, :num_cond_frames] = strength
+
+                else:
+                    if num_data_frames > 1:
+                        if num_cond_frames < num_prefix_latent_frames:
+                            raise ValueError(
+                                f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_cond_frames}."
+                            )
+
+                    if num_cond_frames > num_prefix_latent_frames:
+                        start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
+                        end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
+                        latents[:, :, start_frame:end_frame] = torch.lerp(
+                            latents[:, :, start_frame:end_frame],
+                            condition_latents[:, :, num_prefix_latent_frames:],
+                            strength,
+                        )
+                        condition_latent_frames_mask[:, start_frame:end_frame] = strength
+                        condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
+
+                noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
+                condition_latents = torch.lerp(noise, condition_latents, strength)
+
+                condition_video_ids = self._prepare_video_ids(
+                    batch_size,
+                    condition_latents.size(2),
+                    latent_height,
+                    latent_width,
+                    patch_size=self.transformer_spatial_patch_size,
+                    patch_size_t=self.transformer_temporal_patch_size,
+                    device=device,
+                )
+                condition_video_ids = self._scale_video_ids(
+                    condition_video_ids,
+                    scale_factor=self.vae_spatial_compression_ratio,
+                    scale_factor_t=self.vae_temporal_compression_ratio,
+                    frame_index=frame_index,
+                    device=device,
+                )
+                condition_latents = self._pack_latents(
+                    condition_latents,
+                    self.transformer_spatial_patch_size,
+                    self.transformer_temporal_patch_size,
+                )
+                condition_conditioning_mask = torch.full(
+                    condition_latents.shape[:2], strength, device=device, dtype=dtype
+                )
 
-            extra_conditioning_latents.append(condition_latents)
-            extra_conditioning_video_ids.append(condition_video_ids)
-            extra_conditioning_mask.append(condition_conditioning_mask)
-            extra_conditioning_num_latents += condition_latents.size(1)
+                extra_conditioning_latents.append(condition_latents)
+                extra_conditioning_video_ids.append(condition_video_ids)
+                extra_conditioning_mask.append(condition_conditioning_mask)
+                extra_conditioning_num_latents += condition_latents.size(1)
 
         video_ids = self._prepare_video_ids(
             batch_size,
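The reworked block above blends each condition into the initial noise with `torch.lerp`, so `strength` interpolates directly between pure noise and the clean condition latents. A standalone illustration (the tensor shapes are arbitrary):

```python
# torch.lerp(noise, condition_latents, strength) == noise + strength * (condition_latents - noise):
# strength=0.0 leaves pure noise, strength=1.0 keeps the condition latents exactly.
import torch

noise = torch.randn(1, 128, 2, 16, 16)
condition_latents = torch.randn(1, 128, 2, 16, 16)

for strength in (0.0, 0.5, 1.0):
    blended = torch.lerp(noise, condition_latents, strength)
    print(
        strength,
        torch.allclose(blended, noise),              # True only at strength=0.0
        torch.allclose(blended, condition_latents),  # True only at strength=1.0
    )
```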
@@ -743,7 +744,10 @@ def prepare_latents(
             patch_size=self.transformer_spatial_patch_size,
             device=device,
         )
-        conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
+        if len(conditions) > 0:
+            conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
+        else:
+            conditioning_mask, extra_conditioning_num_latents = None, 0
         video_ids = self._scale_video_ids(
             video_ids,
             scale_factor=self.vae_spatial_compression_ratio,
@@ -755,7 +759,7 @@ def prepare_latents(
             latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
         )
 
-        if len(extra_conditioning_latents) > 0:
+        if len(conditions) > 0 and len(extra_conditioning_latents) > 0:
             latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
             video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2)
             conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1)
@@ -955,7 +959,7 @@ def __call__(
             frame_index = [condition.frame_index for condition in conditions]
             image = [condition.image for condition in conditions]
             video = [condition.video for condition in conditions]
-        else:
+        elif image is not None or video is not None:
            if not isinstance(image, list):
                image = [image]
                num_conditions = 1
@@ -999,32 +1003,34 @@ def __call__(
         vae_dtype = self.vae.dtype
 
         conditioning_tensors = []
-        for condition_image, condition_video, condition_frame_index, condition_strength in zip(
-            image, video, frame_index, strength
-        ):
-            if condition_image is not None:
-                condition_tensor = (
-                    self.video_processor.preprocess(condition_image, height, width)
-                    .unsqueeze(2)
-                    .to(device, dtype=vae_dtype)
-                )
-            elif condition_video is not None:
-                condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
-                num_frames_input = condition_tensor.size(2)
-                num_frames_output = self.trim_conditioning_sequence(
-                    condition_frame_index, num_frames_input, num_frames
-                )
-                condition_tensor = condition_tensor[:, :, :num_frames_output]
-                condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
-            else:
-                raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")
-
-            if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
-                raise ValueError(
-                    f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
-                    f"but got {condition_tensor.size(2)} frames."
-                )
-            conditioning_tensors.append(condition_tensor)
+        is_conditioning_image_or_video = image is not None or video is not None
+        if is_conditioning_image_or_video:
+            for condition_image, condition_video, condition_frame_index, condition_strength in zip(
+                image, video, frame_index, strength
+            ):
+                if condition_image is not None:
+                    condition_tensor = (
+                        self.video_processor.preprocess(condition_image, height, width)
+                        .unsqueeze(2)
+                        .to(device, dtype=vae_dtype)
+                    )
+                elif condition_video is not None:
+                    condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
+                    num_frames_input = condition_tensor.size(2)
+                    num_frames_output = self.trim_conditioning_sequence(
+                        condition_frame_index, num_frames_input, num_frames
+                    )
+                    condition_tensor = condition_tensor[:, :, :num_frames_output]
+                    condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
+                else:
+                    raise ValueError("Either `image` or `video` must be provided for conditioning.")
+
+                if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
+                    raise ValueError(
+                        f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
+                        f"but got {condition_tensor.size(2)} frames."
+                    )
+                conditioning_tensors.append(condition_tensor)
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
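The guard above enforces that every conditioning clip has `k * vae_temporal_compression_ratio + 1` frames. A hedged sketch of mixed image/video conditioning that respects it (the file names and checkpoint id are hypothetical placeholders):

```python
# Hedged sketch: conditioning on an image at frame 0 and a short clip later on.
# `input.png` / `input.mp4` and the checkpoint id are hypothetical placeholders.
import torch
from diffusers import LTXConditionPipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import export_to_video, load_image, load_video

pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = load_image("input.png")
video = load_video("input.mp4")[:49]  # 49 = 6 * 8 + 1, satisfies the frame-count check

condition1 = LTXVideoCondition(image=image, frame_index=0)
condition2 = LTXVideoCondition(video=video, frame_index=80)

video_out = pipe(
    conditions=[condition1, condition2],
    prompt="The scene continues smoothly with consistent lighting and motion",
    num_frames=161,
    generator=torch.Generator("cuda").manual_seed(0),
).frames[0]
export_to_video(video_out, "output.mp4", fps=24)
```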
@@ -1045,7 +1051,7 @@ def __call__(
         video_coords = video_coords.float()
         video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate)
 
-        init_latents = latents.clone()
+        init_latents = latents.clone() if is_conditioning_image_or_video else None
 
         if self.do_classifier_free_guidance:
             video_coords = torch.cat([video_coords, video_coords], dim=0)
@@ -1065,7 +1071,7 @@ def __call__(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
-        # 7. Denoising loop
+        # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
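The context above rescales the temporal row of `video_coords` from frame indices to seconds via `1.0 / frame_rate`. A tiny standalone check (the coordinate values below are made up for illustration):

```python
# The first row of `video_coords` holds frame indices; dividing by the frame
# rate turns them into timestamps in seconds. Values here are illustrative.
import torch

frame_rate = 25
video_coords = torch.tensor(
    [[[0.0, 8.0, 16.0],   # frame indices
      [0.0, 32.0, 32.0],  # height coordinates
      [0.0, 0.0, 32.0]]]  # width coordinates
)
video_coords = video_coords.float()
video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate)
print(video_coords[:, 0])  # tensor([[0.0000, 0.3200, 0.6400]])
```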
@@ -1073,7 +1079,7 @@ def __call__(
                     continue
 
                 self._current_timestep = t
-                if image_cond_noise_scale > 0:
+                if image_cond_noise_scale > 0 and init_latents is not None:
                     # Add timestep-dependent noise to the hard-conditioning latents
                     # This helps with motion continuity, especially when conditioned on a single frame
                     latents = self.add_noise_to_image_conditioning_latents(
@@ -1086,16 +1092,18 @@ def __call__(
                     )
 
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                conditioning_mask_model_input = (
-                    torch.cat([conditioning_mask, conditioning_mask])
-                    if self.do_classifier_free_guidance
-                    else conditioning_mask
-                )
+                if is_conditioning_image_or_video:
+                    conditioning_mask_model_input = (
+                        torch.cat([conditioning_mask, conditioning_mask])
+                        if self.do_classifier_free_guidance
+                        else conditioning_mask
+                    )
                 latent_model_input = latent_model_input.to(prompt_embeds.dtype)
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
-                timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
+                if is_conditioning_image_or_video:
+                    timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
 
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
@@ -1115,8 +1123,11 @@ def __call__(
                 denoised_latents = self.scheduler.step(
                     -noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False
                 )[0]
-                tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
-                latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
+                if is_conditioning_image_or_video:
+                    tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
+                    latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
+                else:
+                    latents = denoised_latents
 
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
@@ -1134,7 +1145,9 @@ def __call__(
             if XLA_AVAILABLE:
                 xm.mark_step()
 
-        latents = latents[:, extra_conditioning_num_latents:]
+        if is_conditioning_image_or_video:
+            latents = latents[:, extra_conditioning_num_latents:]
+
         latents = self._unpack_latents(
             latents,
             latent_num_frames,
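For reference, the per-token timestep clamp guarded above freezes fully conditioned tokens: a token with conditioning strength `s` never sees a timestep above `(1 - s) * 1000`, and the same mask gates which tokens are actually updated each step. A small standalone check:

```python
# Standalone illustration of the conditioning-mask timestep clamp used in the
# denoising loop: strength-1.0 tokens get timestep 0 (left untouched), partially
# conditioned tokens are capped, and unconditioned tokens keep the scheduler's t.
import torch

t = torch.tensor(800.0)
conditioning_mask = torch.tensor([[1.0, 0.8, 0.0]])  # three packed latent tokens

timestep = t.expand(1).unsqueeze(-1).float()  # shape (1, 1), broadcasts over tokens
timestep = torch.min(timestep, (1 - conditioning_mask) * 1000.0)
print(timestep)  # tensor([[  0., 200., 800.]])

# The same mask decides which tokens are denoised at this step:
tokens_to_denoise_mask = t / 1000 - 1e-6 < (1.0 - conditioning_mask)
print(tokens_to_denoise_mask)  # tensor([[False, False,  True]])
```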