Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 71 additions & 43 deletions tests/models/transformers/test_models_transformer_ltx.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -13,26 +12,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import LTXVideoTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)


enable_full_determinism()


class LTXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = LTXVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
class LTXTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return LTXVideoTransformer3DModel

@property
def output_shape(self) -> tuple[int, int]:
return (512, 4)

@property
def dummy_input(self):
def input_shape(self) -> tuple[int, int]:
return (512, 4)

@property
def main_input_name(self) -> str:
return "hidden_states"

@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def get_init_dict(self):
return {
"in_channels": 4,
"out_channels": 4,
"num_attention_heads": 2,
"attention_head_dim": 8,
"cross_attention_dim": 16,
"num_layers": 1,
"qk_norm": "rms_norm_across_heads",
"caption_channels": 16,
}

def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 2
num_channels = 4
num_frames = 2
Expand All @@ -41,50 +72,47 @@ def dummy_input(self):
embedding_dim = 16
sequence_length = 16

hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)

return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
"encoder_attention_mask": encoder_attention_mask,
"hidden_states": randn_tensor(
(batch_size, num_frames * height * width, num_channels),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).bool().to(torch_device),
"num_frames": num_frames,
"height": height,
"width": width,
}

@property
def input_shape(self):
return (512, 4)

@property
def output_shape(self):
return (512, 4)
class TestLTXTransformer(LTXTransformerTesterConfig, ModelTesterMixin):
"""Core model tests for LTX Video Transformer."""

def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"num_attention_heads": 2,
"attention_head_dim": 8,
"cross_attention_dim": 16,
"num_layers": 1,
"qk_norm": "rms_norm_across_heads",
"caption_channels": 16,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

class TestLTXTransformerMemory(LTXTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for LTX Video Transformer."""


class TestLTXTransformerTraining(LTXTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for LTX Video Transformer."""

def test_gradient_checkpointing_is_applied(self):
expected_set = {"LTXVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
super().test_gradient_checkpointing_is_applied(expected_set={"LTXVideoTransformer3DModel"})


class TestLTXTransformerCompile(LTXTransformerTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for LTX Video Transformer."""


# TODO: Add pretrained_model_name_or_path once a tiny LTX model is available on the Hub
# class TestLTXTransformerBitsAndBytes(LTXTransformerTesterConfig, BitsAndBytesTesterMixin):
# """BitsAndBytes quantization tests for LTX Video Transformer."""

class LTXTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = LTXVideoTransformer3DModel

def prepare_init_args_and_inputs_for_common(self):
return LTXTransformerTests().prepare_init_args_and_inputs_for_common()
# TODO: Add pretrained_model_name_or_path once a tiny LTX model is available on the Hub
# class TestLTXTransformerTorchAo(LTXTransformerTesterConfig, TorchAoTesterMixin):
# """TorchAo quantization tests for LTX Video Transformer."""
Loading