Skip to content

Commit 93af686

Browse files
committed
Pipeline: Enhance inference pipelines with new features
* Adaptive normalization after latent upsampling
* CFG Star Rescale
* Varying STG/CFG parameters per step
* Support skipping the initial and/or the final diffusion steps
* CRF compression for image condition (useful for getting more motion in image-to-video)
1 parent cb6f842 commit 93af686

File tree

10 files changed

+220
-43
lines changed

10 files changed

+220
-43
lines changed

configs/ltxv-13b-0.9.7-dev.yaml

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
pipeline_type: multi-scale
32
checkpoint_path: "ltxv-13b-0.9.7-dev.safetensors"
43
downscale_factor: 0.6666666
@@ -14,20 +13,22 @@ prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-P
1413
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
1514
stochastic_sampling: false
1615

17-
1816
first_pass:
19-
guidance_scale: [3]
20-
stg_scale: [1]
21-
rescaling_scale: [0.7]
22-
guidance_timesteps: [1.0]
23-
skip_block_list: [19] # [[1], [1,2], [1,2,3], [27], [28], [28]]
17+
guidance_scale: [1, 1, 6, 8, 6, 1, 1]
18+
stg_scale: [0, 0, 4, 4, 4, 2, 1]
19+
rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
20+
guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
21+
skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
2422
num_inference_steps: 30
23+
skip_final_inference_steps: 3
24+
cfg_star_rescale: true
2525

2626
second_pass:
27-
guidance_scale: [3]
27+
guidance_scale: [1]
2828
stg_scale: [1]
29-
rescaling_scale: [0.7]
29+
rescaling_scale: [1]
3030
guidance_timesteps: [1.0]
31-
skip_block_list: [19] # [[1], [1,2], [1,2,3], [27], [28], [28]]
32-
num_inference_steps: 10
33-
strength: 0.85
31+
skip_block_list: [27]
32+
num_inference_steps: 30
33+
skip_initial_inference_steps: 17
34+
cfg_star_rescale: true

configs/ltxv-2b-0.9.1.yaml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
pipeline_type: base
2+
checkpoint_path: "ltx-video-2b-v0.9.1.safetensors"
3+
guidance_scale: 3
4+
stg_scale: 1
5+
rescaling_scale: 0.7
6+
skip_block_list: [19]
7+
num_inference_steps: 40
8+
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
9+
decode_timestep: 0.05
10+
decode_noise_scale: 0.025
11+
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
12+
precision: "bfloat16"
13+
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
14+
prompt_enhancement_words_threshold: 120
15+
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
16+
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
17+
stochastic_sampling: false

configs/ltxv-2b-0.9.5.yaml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
pipeline_type: base
2+
checkpoint_path: "ltx-video-2b-v0.9.5.safetensors"
3+
guidance_scale: 3
4+
stg_scale: 1
5+
rescaling_scale: 0.7
6+
skip_block_list: [19]
7+
num_inference_steps: 40
8+
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
9+
decode_timestep: 0.05
10+
decode_noise_scale: 0.025
11+
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
12+
precision: "bfloat16"
13+
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
14+
prompt_enhancement_words_threshold: 120
15+
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
16+
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
17+
stochastic_sampling: false

configs/ltxv-2b-0.9.6-distilled.yaml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
pipeline_type: base
22
checkpoint_path: "ltxv-2b-0.9.6-distilled-04-25.safetensors"
3-
guidance_scale: 3
4-
stg_scale: 1
5-
rescaling_scale: 0.7
6-
skip_block_list: [19]
3+
guidance_scale: 1
4+
stg_scale: 0
5+
rescaling_scale: 1
76
num_inference_steps: 8
87
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
98
decode_timestep: 0.05

configs/ltxv-2b-0.9.yaml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
pipeline_type: base
2+
checkpoint_path: "ltx-video-2b-v0.9.safetensors"
3+
guidance_scale: 3
4+
stg_scale: 1
5+
rescaling_scale: 0.7
6+
skip_block_list: [19]
7+
num_inference_steps: 40
8+
stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
9+
decode_timestep: 0.05
10+
decode_noise_scale: 0.025
11+
text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
12+
precision: "bfloat16"
13+
sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
14+
prompt_enhancement_words_threshold: 120
15+
prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
16+
prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
17+
stochastic_sampling: false

inference.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import json
1212
import numpy as np
1313
import torch
14+
import cv2
1415
from safetensors import safe_open
1516
from PIL import Image
1617
from transformers import (
@@ -35,6 +36,7 @@
3536
from ltx_video.schedulers.rf import RectifiedFlowScheduler
3637
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
3738
from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
39+
import ltx_video.pipelines.crf_compressor as crf_compressor
3840

3941
MAX_HEIGHT = 720
4042
MAX_WIDTH = 1280
@@ -96,7 +98,12 @@ def load_image_to_tensor_with_resize_and_crop(
9698
image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
9799
if not just_crop:
98100
image = image.resize((target_width, target_height))
99-
frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
101+
102+
image = np.array(image)
103+
image = cv2.GaussianBlur(image, (3, 3), 0)
104+
frame_tensor = torch.from_numpy(image).float()
105+
frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0
106+
frame_tensor = frame_tensor.permute(2, 0, 1)
100107
frame_tensor = (frame_tensor / 127.5) - 1.0
101108
# Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
102109
return frame_tensor.unsqueeze(0).unsqueeze(2)
@@ -266,13 +273,6 @@ def main():
266273
help="Path to the input video (or imaage) to be modified using the video-to-video pipeline",
267274
)
268275

269-
parser.add_argument(
270-
"--strength",
271-
type=float,
272-
default=1.0,
273-
help="Editing strength (noising level) for video-to-video pipeline.",
274-
)
275-
276276
# Conditioning arguments
277277
parser.add_argument(
278278
"--conditioning_media_paths",
@@ -407,7 +407,6 @@ def infer(
407407
negative_prompt: str,
408408
offload_to_cpu: bool,
409409
input_media_path: Optional[str] = None,
410-
strength: Optional[float] = 1.0,
411410
conditioning_media_paths: Optional[List[str]] = None,
412411
conditioning_strengths: Optional[List[float]] = None,
413412
conditioning_start_frames: Optional[List[int]] = None,
@@ -614,7 +613,6 @@ def infer(
614613
frame_rate=frame_rate,
615614
**sample,
616615
media_items=media_item,
617-
strength=strength,
618616
conditioning_items=conditioning_items,
619617
is_video=True,
620618
vae_per_channel_normalize=True,
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import av
2+
import torch
3+
import io
4+
import numpy as np
5+
6+
7+
def _encode_single_frame(output_file, image_array: np.ndarray, crf):
    """Encode one RGB frame into *output_file* as a single-frame H.264/MP4.

    *crf* is the x264 constant-rate-factor controlling how lossy the
    encoding is (higher = more compression artifacts).
    """
    out = av.open(output_file, "w", format="mp4")
    try:
        video_stream = out.add_stream(
            "libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
        )
        frame_height, frame_width = image_array.shape[0], image_array.shape[1]
        video_stream.height = frame_height
        video_stream.width = frame_width
        # x264 consumes yuv420p, so convert the packed-RGB frame first.
        yuv_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
            format="yuv420p"
        )
        # Encode the frame, then flush the encoder's buffered packets.
        out.mux(video_stream.encode(yuv_frame))
        out.mux(video_stream.encode())
    finally:
        out.close()
22+
23+
24+
def _decode_single_frame(video_file):
    """Decode the first video frame of *video_file* and return it as an RGB ndarray."""
    demuxer = av.open(video_file)
    try:
        # Grab the first video stream and pull exactly one decoded frame.
        video_stream = next(s for s in demuxer.streams if s.type == "video")
        first_frame = next(demuxer.decode(video_stream))
    finally:
        demuxer.close()
    return first_frame.to_ndarray(format="rgb24")
32+
33+
34+
def compress(image: torch.Tensor, crf=29):
    """Round-trip *image* through a single-frame H.264 encode/decode.

    This simulates video-compression artifacts on a still image.  *image* is
    a float HxWx3 tensor (values presumably in [0, 1] — scaled by 255 before
    encoding); the result has the same dtype and device.  ``crf == 0``
    disables compression and returns the input tensor unchanged.
    """
    if crf == 0:
        return image

    # yuv420p needs even spatial dimensions, so trim to even height/width,
    # then hand the encoder a uint8 array on the CPU.
    even_h = (image.shape[0] // 2) * 2
    even_w = (image.shape[1] // 2) * 2
    image_array = (image[:even_h, :even_w] * 255.0).byte().cpu().numpy()

    with io.BytesIO() as output_file:
        _encode_single_frame(output_file, image_array, crf)
        video_bytes = output_file.getvalue()

    with io.BytesIO(video_bytes) as video_file:
        image_array = _decode_single_frame(video_file)

    return torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0

0 commit comments

Comments
 (0)