diff --git a/tests/pipelines/wan/test_wan.py b/tests/pipelines/wan/test_wan.py
index 842b9d19b34e..fdb2d298356e 100644
--- a/tests/pipelines/wan/test_wan.py
+++ b/tests/pipelines/wan/test_wan.py
@@ -15,7 +15,6 @@
 import gc
 import unittest
 
-import numpy as np
 import torch
 from transformers import AutoTokenizer, T5EncoderModel
 
@@ -29,9 +28,7 @@
 )
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import (
-    PipelineTesterMixin,
-)
+from ..test_pipelines_common import PipelineTesterMixin
 
 
 enable_full_determinism()
@@ -127,11 +124,15 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (9, 3, 16, 16))
-        expected_video = torch.randn(9, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4525, 0.452, 0.4485, 0.4534, 0.4524, 0.4529, 0.454, 0.453, 0.5127, 0.5326, 0.5204, 0.5253, 0.5439, 0.5424, 0.5133, 0.5078])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
 
     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):
diff --git a/tests/pipelines/wan/test_wan_image_to_video.py b/tests/pipelines/wan/test_wan_image_to_video.py
index 22dfef2eb0b1..6edc0cc882f7 100644
--- a/tests/pipelines/wan/test_wan_image_to_video.py
+++ b/tests/pipelines/wan/test_wan_image_to_video.py
@@ -14,7 +14,6 @@
 
 import unittest
 
-import numpy as np
 import torch
 from PIL import Image
 from transformers import (
@@ -147,11 +146,15 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (9, 3, 16, 16))
-        expected_video = torch.randn(9, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4525, 0.4525, 0.4497, 0.4536, 0.452, 0.4529, 0.454, 0.4535, 0.5072, 0.5527, 0.5165, 0.5244, 0.5481, 0.5282, 0.5208, 0.5214])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
 
     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):
@@ -162,7 +165,25 @@ def test_inference_batch_single_identical(self):
         pass
 
 
-class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
+class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = WanImageToVideoPipeline
+    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
+    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    required_optional_params = frozenset(
+        [
+            "num_inference_steps",
+            "generator",
+            "latents",
+            "return_dict",
+            "callback_on_step_end",
+            "callback_on_step_end_tensor_inputs",
+        ]
+    )
+    test_xformers_attention = False
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         vae = AutoencoderKLWan(
@@ -247,3 +268,32 @@ def get_dummy_inputs(self, device, seed=0):
             "output_type": "pt",
         }
         return inputs
+
+    def test_inference(self):
+        device = "cpu"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        video = pipe(**inputs).frames
+        generated_video = video[0]
+        self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4531, 0.4527, 0.4498, 0.4542, 0.4526, 0.4527, 0.4534, 0.4534, 0.5061, 0.5185, 0.5283, 0.5181, 0.5309, 0.5365, 0.5113, 0.5244])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+    @unittest.skip("Test not supported")
+    def test_attention_slicing_forward_pass(self):
+        pass
+
+    @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
+    def test_inference_batch_single_identical(self):
+        pass
diff --git a/tests/pipelines/wan/test_wan_video_to_video.py b/tests/pipelines/wan/test_wan_video_to_video.py
index 11c748424a30..f4bb0960acee 100644
--- a/tests/pipelines/wan/test_wan_video_to_video.py
+++ b/tests/pipelines/wan/test_wan_video_to_video.py
@@ -14,7 +14,6 @@
 
 import unittest
 
-import numpy as np
 import torch
 from PIL import Image
 from transformers import AutoTokenizer, T5EncoderModel
@@ -123,11 +122,15 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (17, 3, 16, 16))
-        expected_video = torch.randn(17, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4522, 0.4534, 0.4532, 0.4553, 0.4526, 0.4538, 0.4533, 0.4547, 0.513, 0.5176, 0.5286, 0.4958, 0.4955, 0.5381, 0.5154, 0.5195])
+        # fmt:on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
 
     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):
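
The old assertions compared the generated frames against torch.randn with an atol of 1e10, a check that could never fail; the updated tests flatten the output, keep the first and last eight values, and compare them to recorded reference slices. A minimal standalone sketch of that comparison pattern (not part of the patch), with a hypothetical tensor standing in for real pipeline output:

import torch

# Hypothetical stand-in for pipeline output; the fast tests produce
# (frames, channels, height, width) = (9, 3, 16, 16).
generated_video = torch.rand(9, 3, 16, 16)

# In the tests, expected_slice holds 16 reference values recorded from a known-good run.
# Here it is derived from the stand-in tensor so the sketch is self-contained.
expected_slice = torch.cat([generated_video.flatten()[:8], generated_video.flatten()[-8:]])

# The check used in the updated tests: flatten the frames, keep the first and last
# eight values, and require agreement within an absolute tolerance of 1e-3.
generated_slice = generated_video.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
assert torch.allclose(generated_slice, expected_slice, atol=1e-3)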