Commit cf3a77e

revert the timestep_scale_multiplier change
1 parent 21502d9 commit cf3a77e

3 files changed: 34 additions, 18 deletions

scripts/convert_ltx_to_diffusers.py

Lines changed: 0 additions & 2 deletions
@@ -105,7 +105,6 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
     "per_channel_statistics.mean-of-means": remove_keys_,
     "per_channel_statistics.mean-of-stds": remove_keys_,
     "model.diffusion_model": remove_keys_,
-    "decoder.timestep_scale_multiplier": remove_keys_,
 }


@@ -271,7 +270,6 @@ def get_vae_config(version: str) -> Dict[str, Any]:
         "decoder_causal": False,
         "spatial_compression_ratio": 32,
         "temporal_compression_ratio": 8,
-        "timestep_scale_multiplier": 1000.0,
     }
     VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
     return config
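
For context, the mapping touched in the first hunk drives a prefix-based cleanup pass when converting an original LTX checkpoint: any state-dict key matching one of the listed prefixes is dropped rather than renamed. Below is a minimal sketch of that pattern; only the `remove_keys_` signature comes from the hunk header, while the `SPECIAL_KEYS_REMAP` name and the dispatch loop are assumptions about the script, not verified code.

```py
from typing import Any, Dict


# Handler signature taken from the hunk header above.
def remove_keys_(key: str, state_dict: Dict[str, Any]) -> None:
    state_dict.pop(key, None)


# Assumed shape of the mapping: checkpoint-key prefix -> cleanup handler.
# After this revert, "decoder.timestep_scale_multiplier" is no longer listed,
# since it now corresponds to a real decoder parameter instead of being dropped.
SPECIAL_KEYS_REMAP = {
    "per_channel_statistics.mean-of-means": remove_keys_,
    "per_channel_statistics.mean-of-stds": remove_keys_,
    "model.diffusion_model": remove_keys_,
}


def apply_special_keys(state_dict: Dict[str, Any]) -> None:
    # Iterate over a snapshot of the keys so handlers can pop entries safely.
    for key in list(state_dict.keys()):
        for pattern, handler in SPECIAL_KEYS_REMAP.items():
            if pattern in key:
                handler(key, state_dict)


sd = {"model.diffusion_model.block.0.weight": 0, "decoder.conv_in.weight": 1}
apply_special_keys(sd)
print(sd)  # {'decoder.conv_in.weight': 1}
```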

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

Lines changed: 3 additions & 5 deletions
@@ -921,14 +921,12 @@ def __init__(
         timestep_conditioning: bool = False,
         upsample_residual: Tuple[bool, ...] = (False, False, False, False),
         upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1),
-        timestep_scale_multiplier: float = 1.0,
     ) -> None:
         super().__init__()

         self.patch_size = patch_size
         self.patch_size_t = patch_size_t
         self.out_channels = out_channels * patch_size**2
-        self.timestep_scale_multiplier = timestep_scale_multiplier

         block_out_channels = tuple(reversed(block_out_channels))
         spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))

@@ -983,7 +981,9 @@ def __init__(
         # timestep embedding
         self.time_embedder = None
         self.scale_shift_table = None
+        self.timestep_scale_multiplier = None
         if timestep_conditioning:
+            self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
             self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
             self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)

@@ -992,7 +992,7 @@ def __init__(
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.conv_in(hidden_states)

-        if temb is not None:
+        if self.timestep_scale_multiplier is not None:
             temb = temb * self.timestep_scale_multiplier

         if torch.is_grad_enabled() and self.gradient_checkpointing:

@@ -1107,7 +1107,6 @@ def __init__(
         decoder_causal: bool = False,
         spatial_compression_ratio: int = None,
         temporal_compression_ratio: int = None,
-        timestep_scale_multiplier: float = 1.0,
     ) -> None:
         super().__init__()

@@ -1138,7 +1137,6 @@ def __init__(
             inject_noise=decoder_inject_noise,
             upsample_residual=upsample_residual,
             upsample_factor=upsample_factor,
-            timestep_scale_multiplier=timestep_scale_multiplier,
         )

         latents_mean = torch.zeros((latent_channels,), requires_grad=False)
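
The net effect of these hunks: `timestep_scale_multiplier` is no longer a config float plumbed through the constructors; it is a learnable scalar created only when `timestep_conditioning` is enabled (initialized to 1000.0, so the matching checkpoint tensor can load into it), and the scaling in `forward` is gated on the parameter rather than on `temb`. A minimal, self-contained sketch of that pattern follows; `TinyDecoder` is an illustrative stand-in, not the diffusers module.

```py
from typing import Optional

import torch
import torch.nn as nn


class TinyDecoder(nn.Module):
    # Stand-in for the decoder's timestep handling after the revert.
    def __init__(self, timestep_conditioning: bool = False) -> None:
        super().__init__()
        # Learnable scalar, created only when conditioning is enabled.
        self.timestep_scale_multiplier = None
        if timestep_conditioning:
            self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))

    def forward(self, temb: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
        # Gate on the parameter, not on `temb`: with conditioning disabled,
        # a provided `temb` passes through unscaled.
        if self.timestep_scale_multiplier is not None:
            temb = temb * self.timestep_scale_multiplier
        return temb


print(TinyDecoder(timestep_conditioning=True)(torch.tensor([0.5])))   # tensor([500.], grad_fn=<MulBackward0>)
print(TinyDecoder(timestep_conditioning=False)(torch.tensor([0.5])))  # tensor([0.5000])
```

One consequence of gating on the parameter instead of on `temb`: when `timestep_conditioning` is enabled, the decoder assumes `temb` is actually provided, which matches how the conditioned path was already used.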

src/diffusers/pipelines/ltx/pipeline_ltx_condition.py

Lines changed: 31 additions & 11 deletions
@@ -46,26 +46,46 @@
     Examples:
         ```py
         >>> import torch
-        >>> from diffusers import LTXConditionPipeline
-        >>> from diffusers.utils import export_to_video, load_image
-
-        >>> pipe = LTXConditionPipeline.from_pretrained("YiYiXu/ltx-95", torch_dtype=torch.bfloat16)
+        >>> from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXConditionPipeline, LTXVideoCondition
+        >>> from diffusers.utils import export_to_video, load_video, load_image
+        >>>
+        >>> pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.1", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")
+        >>>
+        >>> # Load input image and video
+        >>> video = load_video(
+        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
+        ... )
         >>> image = load_image(
-        ...     "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
+        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg"
+        ... )
+        >>>
+        >>> # Create conditioning objects
+        >>> condition1 = LTXVideoCondition(
+        ...     image=image,
+        ...     frame_index=0,
         ... )
-        >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background. Flames engulf the structure, with smoke billowing into the air. Firefighters in protective gear rush to the scene, a fire truck labeled '38' visible behind them. The girl's neutral expression contrasts sharply with the chaos of the fire, creating a poignant and emotionally charged scene."
+        >>> condition2 = LTXVideoCondition(
+        ...     video=video,
+        ...     frame_index=80,
+        ... )
+        >>>
+        >>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day. And then the camera switch to a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
         >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
-
+        >>>
+        >>> # Generate video
+        >>> generator = torch.Generator("cuda").manual_seed(0)
         >>> video = pipe(
-        ...     image=image,
+        ...     conditions=[condition1, condition2],
         ...     prompt=prompt,
         ...     negative_prompt=negative_prompt,
-        ...     width=704,
-        ...     height=480,
+        ...     width=768,
+        ...     height=512,
         ...     num_frames=161,
-        ...     num_inference_steps=50,
+        ...     num_inference_steps=40,
+        ...     generator=generator,
         ... ).frames[0]
+        >>>
         >>> export_to_video(video, "output.mp4", fps=24)
         ```
 """
