Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
6c7e720
Initial implementation of perturbed attn processor for LTX 2.3
dg845 Mar 6, 2026
e90b90a
Update DiT block for LTX 2.3 + add self_attention_mask
dg845 Mar 7, 2026
f768f8d
Add flag to control using perturbed attn processor for now
dg845 Mar 8, 2026
cde6748
Add support for new video upsampling blocks used by LTX-2.3
dg845 Mar 8, 2026
236eb8d
Support LTX-2.3 Big-VGAN V2-style vocoder
dg845 Mar 8, 2026
1e89cb3
Initial implementation of LTX-2.3 vocoder with bandwidth extender
dg845 Mar 8, 2026
5a44adb
Initial support for LTX-2.3 per-modality feature extractor
dg845 Mar 9, 2026
4ff3168
Refactor so that text connectors own all text encoder hidden_states n…
dg845 Mar 9, 2026
835bed6
Fix some bugs for inference
dg845 Mar 9, 2026
19004ef
Fix LTX-2.X DiT block forward pass
dg845 Mar 9, 2026
4dfcfeb
Support prompt timestep embeds and prompt cross attn modulation
dg845 Mar 9, 2026
13292dd
Add LTX-2.3 configs to conversion script
dg845 Mar 10, 2026
0528fde
Support converting LTX-2.3 DiT checkpoints
dg845 Mar 10, 2026
c5e1fcc
Support converting LTX-2.3 Video VAE checkpoints
dg845 Mar 10, 2026
50da4df
Support converting LTX-2.3 Vocoder with bandwidth extender
dg845 Mar 10, 2026
4206280
Support converting LTX-2.3 text connectors
dg845 Mar 10, 2026
e719d32
Don't convert any upsamplers for now
dg845 Mar 10, 2026
fbb50d9
Support self attention mask for LTX2Pipeline
dg845 Mar 10, 2026
de3f869
Fix some inference bugs
dg845 Mar 10, 2026
5056aa8
Support self attn mask and sigmas for LTX-2.3 I2V, Cond pipelines
dg845 Mar 11, 2026
f875031
Support STG and modality isolation guidance for LTX-2.3
dg845 Mar 11, 2026
652d363
make style and make quality
dg845 Mar 11, 2026
d018534
Make audio guidance values default to the corresponding video guidance values
dg845 Mar 11, 2026
c0bb2ef
Update to LTX-2.3 style guidance rescaling
dg845 Mar 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
310 changes: 280 additions & 30 deletions scripts/convert_ltx2_to_diffusers.py

Large diffs are not rendered by default.

88 changes: 68 additions & 20 deletions src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def forward(


# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d
class LTXVideoDownsampler3d(nn.Module):
class LTX2VideoDownsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
Expand Down Expand Up @@ -285,10 +285,11 @@ def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Ten


# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d
class LTXVideoUpsampler3d(nn.Module):
class LTX2VideoUpsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int | None = None,
stride: int | tuple[int, int, int] = 1,
residual: bool = False,
upscale_factor: int = 1,
Expand All @@ -300,7 +301,8 @@ def __init__(
self.residual = residual
self.upscale_factor = upscale_factor

out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
out_channels = out_channels or in_channels
out_channels = (out_channels * stride[0] * stride[1] * stride[2]) // upscale_factor

self.conv = LTX2VideoCausalConv3d(
in_channels=in_channels,
Expand Down Expand Up @@ -408,7 +410,7 @@ def __init__(
)
elif downsample_type == "spatial":
self.downsamplers.append(
LTXVideoDownsampler3d(
LTX2VideoDownsampler3d(
in_channels=in_channels,
out_channels=out_channels,
stride=(1, 2, 2),
Expand All @@ -417,7 +419,7 @@ def __init__(
)
elif downsample_type == "temporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
LTX2VideoDownsampler3d(
in_channels=in_channels,
out_channels=out_channels,
stride=(2, 1, 1),
Expand All @@ -426,7 +428,7 @@ def __init__(
)
elif downsample_type == "spatiotemporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
LTX2VideoDownsampler3d(
in_channels=in_channels,
out_channels=out_channels,
stride=(2, 2, 2),
Expand Down Expand Up @@ -580,6 +582,7 @@ def __init__(
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
spatio_temporal_scale: bool = True,
upsample_type: str = "spatiotemporal",
inject_noise: bool = False,
timestep_conditioning: bool = False,
upsample_residual: bool = False,
Expand Down Expand Up @@ -609,17 +612,38 @@ def __init__(

self.upsamplers = None
if spatio_temporal_scale:
self.upsamplers = nn.ModuleList(
[
LTXVideoUpsampler3d(
out_channels * upscale_factor,
self.upsamplers = nn.ModuleList()

if upsample_type == "spatial":
self.upsamplers.append(
LTX2VideoUpsampler3d(
in_channels=out_channels * upscale_factor,
stride=(1, 2, 2),
residual=upsample_residual,
upscale_factor=upscale_factor,
spatial_padding_mode=spatial_padding_mode,
)
)
elif upsample_type == "temporal":
self.upsamplers.append(
LTX2VideoUpsampler3d(
in_channels=out_channels * upscale_factor,
stride=(2, 1, 1),
residual=upsample_residual,
upscale_factor=upscale_factor,
spatial_padding_mode=spatial_padding_mode,
)
)
elif upsample_type == "spatiotemporal":
self.upsamplers.append(
LTX2VideoUpsampler3d(
in_channels=out_channels * upscale_factor,
stride=(2, 2, 2),
residual=upsample_residual,
upscale_factor=upscale_factor,
spatial_padding_mode=spatial_padding_mode,
)
]
)
)

resnets = []
for _ in range(num_layers):
Expand Down Expand Up @@ -716,7 +740,7 @@ def __init__(
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
),
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True),
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True),
layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2),
downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
patch_size: int = 4,
Expand All @@ -726,6 +750,9 @@ def __init__(
spatial_padding_mode: str = "zeros",
):
super().__init__()
num_encoder_blocks = len(layers_per_block)
if isinstance(spatio_temporal_scaling, bool):
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1)

self.patch_size = patch_size
self.patch_size_t = patch_size_t
Expand Down Expand Up @@ -860,19 +887,27 @@ def __init__(
in_channels: int = 128,
out_channels: int = 3,
block_out_channels: tuple[int, ...] = (256, 512, 1024),
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True),
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True),
layers_per_block: tuple[int, ...] = (5, 5, 5, 5),
upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
is_causal: bool = False,
inject_noise: tuple[bool, ...] = (False, False, False),
inject_noise: bool | tuple[bool, ...] = (False, False, False),
timestep_conditioning: bool = False,
upsample_residual: tuple[bool, ...] = (True, True, True),
upsample_residual: bool | tuple[bool, ...] = (True, True, True),
upsample_factor: tuple[bool, ...] = (2, 2, 2),
spatial_padding_mode: str = "reflect",
) -> None:
super().__init__()
num_decoder_blocks = len(layers_per_block)
if isinstance(spatio_temporal_scaling, bool):
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_decoder_blocks - 1)
if isinstance(inject_noise, bool):
inject_noise = (inject_noise,) * num_decoder_blocks
if isinstance(upsample_residual, bool):
upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1)

self.patch_size = patch_size
self.patch_size_t = patch_size_t
Expand Down Expand Up @@ -917,6 +952,7 @@ def __init__(
num_layers=layers_per_block[i + 1],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
upsample_type=upsample_type[i],
inject_noise=inject_noise[i + 1],
timestep_conditioning=timestep_conditioning,
upsample_residual=upsample_residual[i],
Expand Down Expand Up @@ -1058,11 +1094,12 @@ def __init__(
decoder_block_out_channels: tuple[int, ...] = (256, 512, 1024),
layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2),
decoder_layers_per_block: tuple[int, ...] = (5, 5, 5, 5),
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True),
decoder_spatio_temporal_scaling: tuple[bool, ...] = (True, True, True),
decoder_inject_noise: tuple[bool, ...] = (False, False, False, False),
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True),
decoder_spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True),
decoder_inject_noise: bool | tuple[bool, ...] = (False, False, False, False),
downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
upsample_residual: tuple[bool, ...] = (True, True, True),
upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
upsample_residual: bool | tuple[bool, ...] = (True, True, True),
upsample_factor: tuple[int, ...] = (2, 2, 2),
timestep_conditioning: bool = False,
patch_size: int = 4,
Expand All @@ -1077,6 +1114,16 @@ def __init__(
temporal_compression_ratio: int = None,
) -> None:
super().__init__()
num_encoder_blocks = len(layers_per_block)
num_decoder_blocks = len(decoder_layers_per_block)
if isinstance(spatio_temporal_scaling, bool):
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1)
if isinstance(decoder_spatio_temporal_scaling, bool):
decoder_spatio_temporal_scaling = (decoder_spatio_temporal_scaling,) * (num_decoder_blocks - 1)
if isinstance(decoder_inject_noise, bool):
decoder_inject_noise = (decoder_inject_noise,) * num_decoder_blocks
if isinstance(upsample_residual, bool):
upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1)

self.encoder = LTX2VideoEncoder3d(
in_channels=in_channels,
Expand All @@ -1098,6 +1145,7 @@ def __init__(
block_out_channels=decoder_block_out_channels,
spatio_temporal_scaling=decoder_spatio_temporal_scaling,
layers_per_block=decoder_layers_per_block,
upsample_type=upsample_type,
patch_size=patch_size,
patch_size_t=patch_size_t,
resnet_norm_eps=resnet_norm_eps,
Expand Down
Loading
Loading