From b28e9204f7bf2f40224246be57e6788f0c551250 Mon Sep 17 00:00:00 2001 From: DN6 Date: Wed, 11 Mar 2026 13:32:16 +0530 Subject: [PATCH 1/2] update --- .../test_models_transformer_ltx.py | 106 ++++--- .../test_models_transformer_ltx2.py | 261 ++++++------------ 2 files changed, 152 insertions(+), 215 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_ltx.py b/tests/models/transformers/test_models_transformer_ltx.py index e912463bbf6a..8b52cf6e780c 100644 --- a/tests/models/transformers/test_models_transformer_ltx.py +++ b/tests/models/transformers/test_models_transformer_ltx.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2025 HuggingFace Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,26 +12,58 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - import torch from diffusers import LTXVideoTransformer3DModel +from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin +from ..testing_utils import ( + BaseModelTesterConfig, + MemoryTesterMixin, + ModelTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) enable_full_determinism() -class LTXTransformerTests(ModelTesterMixin, unittest.TestCase): - model_class = LTXVideoTransformer3DModel - main_input_name = "hidden_states" - uses_custom_attn_processor = True +class LTXTransformerTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return LTXVideoTransformer3DModel + + @property + def output_shape(self) -> tuple[int, int]: + return (512, 4) + + @property + def input_shape(self) -> tuple[int, int]: + return (512, 4) + + @property + def main_input_name(self) -> str: + return "hidden_states" @property - def dummy_input(self): + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self): + return { + "in_channels": 4, + "out_channels": 4, + "num_attention_heads": 2, + "attention_head_dim": 8, + "cross_attention_dim": 16, + "num_layers": 1, + "qk_norm": "rms_norm_across_heads", + "caption_channels": 16, + } + + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: batch_size = 2 num_channels = 4 num_frames = 2 @@ -41,50 +72,37 @@ def dummy_input(self): embedding_dim = 16 sequence_length = 16 - hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device) - encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) - encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device) - timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) - return { - "hidden_states": hidden_states, - "encoder_hidden_states": encoder_hidden_states, - "timestep": timestep, - "encoder_attention_mask": encoder_attention_mask, + "hidden_states": randn_tensor( + (batch_size, num_frames * height * width, num_channels), + generator=self.generator, + device=torch_device, + ), + "encoder_hidden_states": randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ), + "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device), + "encoder_attention_mask": torch.ones((batch_size, sequence_length)).bool().to(torch_device), "num_frames": num_frames, "height": height, "width": width, } - @property - def input_shape(self): - return (512, 4) - @property - def output_shape(self): - return (512, 4) +class TestLTXTransformer(LTXTransformerTesterConfig, ModelTesterMixin): + """Core model tests for LTX Video Transformer.""" - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "in_channels": 4, - "out_channels": 4, - "num_attention_heads": 2, - "attention_head_dim": 8, - "cross_attention_dim": 16, - "num_layers": 1, - "qk_norm": "rms_norm_across_heads", - "caption_channels": 16, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - def test_gradient_checkpointing_is_applied(self): - expected_set = {"LTXVideoTransformer3DModel"} - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) +class TestLTXTransformerMemory(LTXTransformerTesterConfig, MemoryTesterMixin): + """Memory optimization tests for LTX Video Transformer.""" + +class TestLTXTransformerTraining(LTXTransformerTesterConfig, TrainingTesterMixin): + """Training tests for LTX Video Transformer.""" + + def test_gradient_checkpointing_is_applied(self): + super().test_gradient_checkpointing_is_applied(expected_set={"LTXVideoTransformer3DModel"}) -class LTXTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = LTXVideoTransformer3DModel - def prepare_init_args_and_inputs_for_common(self): - return LTXTransformerTests().prepare_init_args_and_inputs_for_common() +class TestLTXTransformerCompile(LTXTransformerTesterConfig, TorchCompileTesterMixin): + """Torch compile tests for LTX Video Transformer.""" diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index af9ef0623891..80519843e276 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2025 HuggingFace Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,77 +12,49 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - +import pytest import torch from diffusers import LTX2VideoTransformer3DModel +from diffusers.utils.torch_utils import randn_tensor from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + MemoryTesterMixin, + ModelTesterMixin, + TorchCompileTesterMixin, + TrainingTesterMixin, +) enable_full_determinism() -class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): - model_class = LTX2VideoTransformer3DModel - main_input_name = "hidden_states" - uses_custom_attn_processor = True - +class LTX2TransformerTesterConfig(BaseModelTesterConfig): @property - def dummy_input(self): - # Common - batch_size = 2 - - # Video - num_frames = 2 - num_channels = 4 - height = 16 - width = 16 - - # Audio - audio_num_frames = 9 - audio_num_channels = 2 - num_mel_bins = 2 - - # Text - embedding_dim = 16 - sequence_length = 16 - - hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device) - audio_hidden_states = torch.randn((batch_size, audio_num_frames, audio_num_channels * num_mel_bins)).to( - torch_device - ) - encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) - audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) - encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device) - timestep = torch.rand((batch_size,)).to(torch_device) * 1000 - - return { - "hidden_states": hidden_states, - "audio_hidden_states": audio_hidden_states, - "encoder_hidden_states": encoder_hidden_states, - "audio_encoder_hidden_states": audio_encoder_hidden_states, - "timestep": timestep, - "encoder_attention_mask": encoder_attention_mask, - "num_frames": num_frames, - "height": height, - "width": width, - "audio_num_frames": audio_num_frames, - "fps": 25.0, - } + def model_class(self): + return LTX2VideoTransformer3DModel @property - def input_shape(self): + def output_shape(self) -> tuple[int, int]: return (512, 4) @property - def output_shape(self): + def input_shape(self) -> tuple[int, int]: return (512, 4) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + @property + def main_input_name(self) -> str: + return "hidden_states" + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) + + def get_init_dict(self): + return { "in_channels": 4, "out_channels": 4, "patch_size": 1, @@ -101,122 +72,70 @@ def prepare_init_args_and_inputs_for_common(self): "caption_channels": 16, "rope_double_precision": False, } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + + def get_dummy_inputs(self) -> dict[str, torch.Tensor]: + batch_size = 2 + num_frames = 2 + num_channels = 4 + height = 16 + width = 16 + audio_num_frames = 9 + audio_num_channels = 2 + num_mel_bins = 2 + embedding_dim = 16 + sequence_length = 16 + + return { + "hidden_states": randn_tensor( + (batch_size, num_frames * height * width, num_channels), + generator=self.generator, + device=torch_device, + ), + "audio_hidden_states": randn_tensor( + (batch_size, audio_num_frames, audio_num_channels * num_mel_bins), + generator=self.generator, + device=torch_device, + ), + "encoder_hidden_states": randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ), + "audio_encoder_hidden_states": randn_tensor( + (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device + ), + "timestep": (randn_tensor((batch_size,), generator=self.generator, device=torch_device).abs() * 1000), + "encoder_attention_mask": torch.ones((batch_size, sequence_length)).bool().to(torch_device), + "num_frames": num_frames, + "height": height, + "width": width, + "audio_num_frames": audio_num_frames, + "fps": 25.0, + } + + +class TestLTX2Transformer(LTX2TransformerTesterConfig, ModelTesterMixin): + """Core model tests for LTX2 Video Transformer.""" + + +class TestLTX2TransformerMemory(LTX2TransformerTesterConfig, MemoryTesterMixin): + """Memory optimization tests for LTX2 Video Transformer.""" + + +class TestLTX2TransformerTraining(LTX2TransformerTesterConfig, TrainingTesterMixin): + """Training tests for LTX2 Video Transformer.""" def test_gradient_checkpointing_is_applied(self): - expected_set = {"LTX2VideoTransformer3DModel"} - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - - # def test_ltx2_consistency(self, seed=0, dtype=torch.float32): - # torch.manual_seed(seed) - # init_dict, _ = self.prepare_init_args_and_inputs_for_common() - - # # Calculate dummy inputs in a custom manner to ensure compatibility with original code - # batch_size = 2 - # num_frames = 9 - # latent_frames = 2 - # text_embedding_dim = 16 - # text_seq_len = 16 - # fps = 25.0 - # sampling_rate = 16000.0 - # hop_length = 160.0 - - # sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") * 1000 - # timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device) - - # num_channels = 4 - # latent_height = 4 - # latent_width = 4 - # hidden_states = torch.randn( - # (batch_size, num_channels, latent_frames, latent_height, latent_width), - # generator=torch.manual_seed(seed), - # dtype=dtype, - # device="cpu", - # ) - # # Patchify video latents (with patch_size (1, 1, 1)) - # hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1) - # hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) - # encoder_hidden_states = torch.randn( - # (batch_size, text_seq_len, text_embedding_dim), - # generator=torch.manual_seed(seed), - # dtype=dtype, - # device="cpu", - # ) - - # audio_num_channels = 2 - # num_mel_bins = 2 - # latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps)) - # audio_hidden_states = torch.randn( - # (batch_size, audio_num_channels, latent_length, num_mel_bins), - # generator=torch.manual_seed(seed), - # dtype=dtype, - # device="cpu", - # ) - # # Patchify audio latents - # audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3) - # audio_encoder_hidden_states = torch.randn( - # (batch_size, text_seq_len, text_embedding_dim), - # generator=torch.manual_seed(seed), - # dtype=dtype, - # device="cpu", - # ) - - # inputs_dict = { - # "hidden_states": hidden_states.to(device=torch_device), - # "audio_hidden_states": audio_hidden_states.to(device=torch_device), - # "encoder_hidden_states": encoder_hidden_states.to(device=torch_device), - # "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device), - # "timestep": timestep, - # "num_frames": latent_frames, - # "height": latent_height, - # "width": latent_width, - # "audio_num_frames": num_frames, - # "fps": 25.0, - # } - - # model = self.model_class.from_pretrained( - # "diffusers-internal-dev/dummy-ltx2", - # subfolder="transformer", - # device_map="cpu", - # ) - # # torch.manual_seed(seed) - # # model = self.model_class(**init_dict) - # model.to(torch_device) - # model.eval() - - # with attention_backend("native"): - # with torch.no_grad(): - # output = model(**inputs_dict) - - # video_output, audio_output = output.to_tuple() - - # self.assertIsNotNone(video_output) - # self.assertIsNotNone(audio_output) - - # # input & output have to have the same shape - # video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels) - # self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match") - # audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins) - # self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match") - - # # Check against expected slice - # # fmt: off - # video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676]) - # audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692]) - # # fmt: on - - # video_output_flat = video_output.cpu().flatten().float() - # video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]]) - # self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4)) - - # audio_output_flat = audio_output.cpu().flatten().float() - # audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]]) - # self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4)) - - -class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): - model_class = LTX2VideoTransformer3DModel - - def prepare_init_args_and_inputs_for_common(self): - return LTX2TransformerTests().prepare_init_args_and_inputs_for_common() + super().test_gradient_checkpointing_is_applied(expected_set={"LTX2VideoTransformer3DModel"}) + + +class TestLTX2TransformerAttention(LTX2TransformerTesterConfig, AttentionTesterMixin): + """Attention processor tests for LTX2 Video Transformer.""" + + @pytest.mark.skip( + "LTX2Attention does not set is_cross_attention, so fuse_projections tries to fuse Q+K+V together even for cross-attention modules with different input dimensions." + ) + def test_fuse_unfuse_qkv_projections(self, atol=1e-3, rtol=0): + pass + + +class TestLTX2TransformerCompile(LTX2TransformerTesterConfig, TorchCompileTesterMixin): + """Torch compile tests for LTX2 Video Transformer.""" From 52bd90a3b36c722638f3c622128e9c9c7a321db4 Mon Sep 17 00:00:00 2001 From: DN6 Date: Wed, 11 Mar 2026 13:40:14 +0530 Subject: [PATCH 2/2] update --- .../models/transformers/test_models_transformer_ltx.py | 10 ++++++++++ .../transformers/test_models_transformer_ltx2.py | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/tests/models/transformers/test_models_transformer_ltx.py b/tests/models/transformers/test_models_transformer_ltx.py index 8b52cf6e780c..95bebd8c3335 100644 --- a/tests/models/transformers/test_models_transformer_ltx.py +++ b/tests/models/transformers/test_models_transformer_ltx.py @@ -106,3 +106,13 @@ def test_gradient_checkpointing_is_applied(self): class TestLTXTransformerCompile(LTXTransformerTesterConfig, TorchCompileTesterMixin): """Torch compile tests for LTX Video Transformer.""" + + +# TODO: Add pretrained_model_name_or_path once a tiny LTX model is available on the Hub +# class TestLTXTransformerBitsAndBytes(LTXTransformerTesterConfig, BitsAndBytesTesterMixin): +# """BitsAndBytes quantization tests for LTX Video Transformer.""" + + +# TODO: Add pretrained_model_name_or_path once a tiny LTX model is available on the Hub +# class TestLTXTransformerTorchAo(LTXTransformerTesterConfig, TorchAoTesterMixin): +# """TorchAo quantization tests for LTX Video Transformer.""" diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index 80519843e276..e0e858bb6916 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -139,3 +139,13 @@ def test_fuse_unfuse_qkv_projections(self, atol=1e-3, rtol=0): class TestLTX2TransformerCompile(LTX2TransformerTesterConfig, TorchCompileTesterMixin): """Torch compile tests for LTX2 Video Transformer.""" + + +# TODO: Add pretrained_model_name_or_path once a tiny LTX2 model is available on the Hub +# class TestLTX2TransformerBitsAndBytes(LTX2TransformerTesterConfig, BitsAndBytesTesterMixin): +# """BitsAndBytes quantization tests for LTX2 Video Transformer.""" + + +# TODO: Add pretrained_model_name_or_path once a tiny LTX2 model is available on the Hub +# class TestLTX2TransformerTorchAo(LTX2TransformerTesterConfig, TorchAoTesterMixin): +# """TorchAo quantization tests for LTX2 Video Transformer."""