Skip to content

Commit 47d93ce

Browse files
committed
update
1 parent cfd6ec7 commit 47d93ce

File tree

3 files changed

+142
-4
lines changed

3 files changed

+142
-4
lines changed

docs/source/en/api/pipelines/ltx_video.md

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,8 @@ export_to_video(video, "output.mp4", fps=24)
254254
pipeline.vae.enable_tiling()
255255

256256
def round_to_nearest_resolution_acceptable_by_vae(height, width):
257-
height = height - (height % pipeline.vae_temporal_compression_ratio)
258-
width = width - (width % pipeline.vae_temporal_compression_ratio)
257+
height = height - (height % pipeline.vae_spatial_compression_ratio)
258+
width = width - (width % pipeline.vae_spatial_compression_ratio)
259259
return height, width
260260

261261
prompt = """
@@ -325,6 +325,95 @@ export_to_video(video, "output.mp4", fps=24)
325325

326326
</details>
327327

328+
- LTX-Video 0.9.8 distilled model is similar to the 0.9.7 variant. It is guidance and timestep-distilled, and similar inference code can be used as above. An improvement of this version is that it supports generating very long videos. Additionally, it supports using tone mapping to improve the quality of the generated video using the `tone_map_compression_ratio` parameter. The default value of `0.6` is recommended.
329+
330+
<details>
331+
<summary>Show example code</summary>
332+
333+
```python
334+
import torch
335+
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
336+
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
337+
from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
338+
from diffusers.utils import export_to_video, load_video
339+
340+
pipeline = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.8-13B-distilled", torch_dtype=torch.bfloat16)
341+
# TODO: Update the checkpoint here once updated in LTX org
342+
upsampler = LTXLatentUpsamplerModel.from_pretrained("a-r-r-o-w/LTX-0.9.8-Latent-Upsampler", torch_dtype=torch.bfloat16)
343+
pipe_upsample = LTXLatentUpsamplePipeline(vae=pipeline.vae, latent_upsampler=upsampler).to(torch.bfloat16)
344+
pipeline.to("cuda")
345+
pipe_upsample.to("cuda")
346+
pipeline.vae.enable_tiling()
347+
348+
def round_to_nearest_resolution_acceptable_by_vae(height, width):
349+
height = height - (height % pipeline.vae_spatial_compression_ratio)
350+
width = width - (width % pipeline.vae_spatial_compression_ratio)
351+
return height, width
352+
353+
prompt = """The camera pans over a snow-covered mountain range, revealing a vast expanse of snow-capped peaks and valleys.The mountains are covered in a thick layer of snow, with some areas appearing almost white while others have a slightly darker, almost grayish hue. The peaks are jagged and irregular, with some rising sharply into the sky while others are more rounded. The valleys are deep and narrow, with steep slopes that are also covered in snow. The trees in the foreground are mostly bare, with only a few leaves remaining on their branches. The sky is overcast, with thick clouds obscuring the sun. The overall impression is one of peace and tranquility, with the snow-covered mountains standing as a testament to the power and beauty of nature."""
354+
# prompt = """A woman walks away from a white Jeep parked on a city street at night, then ascends a staircase and knocks on a door. The woman, wearing a dark jacket and jeans, walks away from the Jeep parked on the left side of the street, her back to the camera; she walks at a steady pace, her arms swinging slightly by her sides; the street is dimly lit, with streetlights casting pools of light on the wet pavement; a man in a dark jacket and jeans walks past the Jeep in the opposite direction; the camera follows the woman from behind as she walks up a set of stairs towards a building with a green door; she reaches the top of the stairs and turns left, continuing to walk towards the building; she reaches the door and knocks on it with her right hand; the camera remains stationary, focused on the doorway; the scene is captured in real-life footage."""
355+
negative_prompt = "bright colors, symbols, graffiti, watermarks, worst quality, inconsistent motion, blurry, jittery, distorted"
356+
expected_height, expected_width = 480, 832
357+
downscale_factor = 2 / 3
358+
# num_frames = 161
359+
num_frames = 361
360+
361+
# 1. Generate video at smaller resolution
362+
downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor)
363+
downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width)
364+
latents = pipeline(
365+
prompt=prompt,
366+
negative_prompt=negative_prompt,
367+
width=downscaled_width,
368+
height=downscaled_height,
369+
num_frames=num_frames,
370+
timesteps=[1000, 993, 987, 981, 975, 909, 725, 0.03],
371+
decode_timestep=0.05,
372+
decode_noise_scale=0.025,
373+
image_cond_noise_scale=0.0,
374+
guidance_scale=1.0,
375+
guidance_rescale=0.7,
376+
generator=torch.Generator().manual_seed(0),
377+
output_type="latent",
378+
).frames
379+
380+
# 2. Upscale generated video using latent upsampler with fewer inference steps
381+
# The available latent upsampler upscales the height/width by 2x
382+
upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2
383+
upscaled_latents = pipe_upsample(
384+
latents=latents,
385+
adain_factor=1.0,
386+
tone_map_compression_ratio=0.6,
387+
output_type="latent"
388+
).frames
389+
390+
# 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)
391+
video = pipeline(
392+
prompt=prompt,
393+
negative_prompt=negative_prompt,
394+
width=upscaled_width,
395+
height=upscaled_height,
396+
num_frames=num_frames,
397+
denoise_strength=0.999, # Effectively, 4 inference steps out of 5
398+
timesteps=[1000, 909, 725, 421, 0],
399+
latents=upscaled_latents,
400+
decode_timestep=0.05,
401+
decode_noise_scale=0.025,
402+
image_cond_noise_scale=0.0,
403+
guidance_scale=1.0,
404+
guidance_rescale=0.7,
405+
generator=torch.Generator().manual_seed(0),
406+
output_type="pil",
407+
).frames[0]
408+
409+
# 4. Downscale the video to the expected resolution
410+
video = [frame.resize((expected_width, expected_height)) for frame in video]
411+
412+
export_to_video(video, "output.mp4", fps=24)
413+
```
414+
415+
</details>
416+
328417
- LTX-Video supports LoRAs with [`~loaders.LTXVideoLoraLoaderMixin.load_lora_weights`].
329418

330419
<details>

scripts/convert_ltx_to_diffusers.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,15 @@ def get_spatial_latent_upsampler_config(version: str) -> Dict[str, Any]:
369369
"spatial_upsample": True,
370370
"temporal_upsample": False,
371371
}
372+
elif version == "0.9.8":
373+
config = {
374+
"in_channels": 128,
375+
"mid_channels": 512,
376+
"num_blocks_per_stage": 4,
377+
"dims": 3,
378+
"spatial_upsample": True,
379+
"temporal_upsample": False,
380+
}
372381
else:
373382
raise ValueError(f"Unsupported version: {version}")
374383
return config
@@ -402,7 +411,7 @@ def get_args():
402411
"--version",
403412
type=str,
404413
default="0.9.0",
405-
choices=["0.9.0", "0.9.1", "0.9.5", "0.9.7"],
414+
choices=["0.9.0", "0.9.1", "0.9.5", "0.9.7", "0.9.8"],
406415
help="Version of the LTX model",
407416
)
408417
return parser.parse_args()

src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,38 @@ def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Te
121121
result = torch.lerp(latents, result, factor)
122122
return result
123123

124+
def tone_map_latents(self, latents: torch.Tensor, compression: float) -> torch.Tensor:
125+
"""
126+
Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually
127+
smooth way using a sigmoid-based compression.
128+
129+
This is useful for regularizing high-variance latents or for conditioning outputs during generation, especially
130+
when controlling dynamic behavior with a `compression` factor.
131+
132+
Args:
133+
latents : torch.Tensor
134+
Input latent tensor with arbitrary shape. Expected to be roughly in [-1, 1] or [0, 1] range.
135+
compression : float
136+
Compression strength in the range [0, 1].
137+
- 0.0: No tone-mapping (identity transform)
138+
- 1.0: Full compression effect
139+
140+
Returns:
141+
torch.Tensor
142+
The tone-mapped latent tensor of the same shape as input.
143+
"""
144+
# Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot
145+
scale_factor = compression * 0.75
146+
abs_latents = torch.abs(latents)
147+
148+
# Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0
149+
# When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect
150+
sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0))
151+
scales = 1.0 - 0.8 * scale_factor * sigmoid_term
152+
153+
filtered = latents * scales
154+
return filtered
155+
124156
@staticmethod
125157
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
126158
def _normalize_latents(
@@ -172,7 +204,7 @@ def disable_vae_tiling(self):
172204
"""
173205
self.vae.disable_tiling()
174206

175-
def check_inputs(self, video, height, width, latents):
207+
def check_inputs(self, video, height, width, latents, tone_map_compression_ratio):
176208
if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0:
177209
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
178210

@@ -181,6 +213,9 @@ def check_inputs(self, video, height, width, latents):
181213
if video is None and latents is None:
182214
raise ValueError("One of `video` or `latents` has to be provided.")
183215

216+
if not (0 <= tone_map_compression_ratio <= 1):
217+
raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]")
218+
184219
@torch.no_grad()
185220
def __call__(
186221
self,
@@ -191,6 +226,7 @@ def __call__(
191226
decode_timestep: Union[float, List[float]] = 0.0,
192227
decode_noise_scale: Optional[Union[float, List[float]]] = None,
193228
adain_factor: float = 0.0,
229+
tone_map_compression_ratio: float = 0.0,
194230
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
195231
output_type: Optional[str] = "pil",
196232
return_dict: bool = True,
@@ -200,6 +236,7 @@ def __call__(
200236
height=height,
201237
width=width,
202238
latents=latents,
239+
tone_map_compression_ratio=tone_map_compression_ratio,
203240
)
204241

205242
if video is not None:
@@ -242,6 +279,9 @@ def __call__(
242279
else:
243280
latents = latents_upsampled
244281

282+
if tone_map_compression_ratio > 0.0:
283+
latents = self.tone_map_latents(latents, tone_map_compression_ratio)
284+
245285
if output_type == "latent":
246286
latents = self._normalize_latents(
247287
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor

0 commit comments

Comments
 (0)