From ad579a797559311d84c4bea18b77baaa691888d7 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Fri, 18 Jul 2025 17:24:08 +0200
Subject: [PATCH] update

---
 .../hunyuan_video/test_hunyuan_image2video.py | 15 +++++++++++----
 .../test_hunyuan_skyreels_image2video.py      | 15 +++++++++++----
 .../hunyuan_video/test_hunyuan_video.py       | 20 ++++++++++++--------
 .../test_hunyuan_video_framepack.py           | 15 +++++++++++----
 4 files changed, 45 insertions(+), 20 deletions(-)

diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py
index 6a4e3a89319a..82281f28bc84 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py
@@ -229,12 +229,19 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
 
-        # NOTE: The expected video has 4 lesser frames because they are dropped in the pipeline
         self.assertEqual(generated_video.shape, (5, 3, 16, 16))
-        expected_video = torch.randn(5, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.444, 0.479, 0.4485, 0.5752, 0.3539, 0.1548, 0.2706, 0.3593, 0.5323, 0.6635, 0.6795, 0.5255, 0.5091, 0.345, 0.4276, 0.4128])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(
+            torch.allclose(generated_slice, expected_slice, atol=1e-3),
+            "The generated video does not match the expected slice.",
+        )
 
     def test_callback_inputs(self):
         sig = inspect.signature(self.pipeline_class.__call__)
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py b/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py
index 94d3c3739f97..fad159c06b0f 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py
@@ -192,11 +192,18 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (9, 3, 16, 16))
-        expected_video = torch.randn(9, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.5832, 0.5498, 0.4839, 0.4744, 0.4515, 0.4832, 0.496, 0.563, 0.5918, 0.5979, 0.5101, 0.6168, 0.6613, 0.536, 0.55, 0.5775])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(
+            torch.allclose(generated_slice, expected_slice, atol=1e-3),
+            "The generated video does not match the expected slice.",
+        )
 
     def test_callback_inputs(self):
         sig = inspect.signature(self.pipeline_class.__call__)
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py
index 10101af75cee..26ec861522a9 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py
@@ -26,10 +26,7 @@
     HunyuanVideoPipeline,
     HunyuanVideoTransformer3DModel,
 )
-from diffusers.utils.testing_utils import (
-    enable_full_determinism,
-    torch_device,
-)
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
 
 from ..test_pipelines_common import (
     FasterCacheTesterMixin,
@@ -206,11 +206,18 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (9, 3, 16, 16))
-        expected_video = torch.randn(9, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.3946, 0.4649, 0.3196, 0.4569, 0.3312, 0.3687, 0.3216, 0.3972, 0.4469, 0.3888, 0.3929, 0.3802, 0.3479, 0.3888, 0.3825, 0.3542])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(
+            torch.allclose(generated_slice, expected_slice, atol=1e-3),
+            "The generated video does not match the expected slice.",
+        )
 
     def test_callback_inputs(self):
         sig = inspect.signature(self.pipeline_class.__call__)
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py b/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py
index 9f685d34c933..297c3df45a10 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py
@@ -227,11 +227,18 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (13, 3, 32, 32))
-        expected_video = torch.randn(13, 3, 32, 32)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.363, 0.3384, 0.3426, 0.3512, 0.3372, 0.3276, 0.417, 0.4061, 0.5221, 0.467, 0.4813, 0.4556, 0.4107, 0.3945, 0.4049, 0.4551])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(
+            torch.allclose(generated_slice, expected_slice, atol=1e-3),
+            "The generated video does not match the expected slice.",
+        )
 
     def test_callback_inputs(self):
         sig = inspect.signature(self.pipeline_class.__call__)
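
Note: all four tests adopt the same expected-slice pattern, replacing assertions that compared the output against torch.randn(...) noise with atol=1e10 (a check that could never fail) with a comparison of 16 boundary values at atol=1e-3. The sketch below is a minimal, self-contained illustration of that pattern, separate from the patch itself; check_expected_slice is a hypothetical helper named here for clarity, and the tensors are dummy data rather than real pipeline output.

import torch

def check_expected_slice(generated_video, expected_slice, atol=1e-3):
    # Flatten the (frames, channels, height, width) tensor and keep only the
    # first and last 8 values, mirroring the assertions in the tests above.
    generated_slice = generated_video.flatten()
    generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
    return torch.allclose(generated_slice, expected_slice, atol=atol)

torch.manual_seed(0)
video = torch.rand(5, 3, 16, 16)  # dummy stand-in for pipe(**inputs).frames[0]
reference = torch.cat([video.flatten()[:8], video.flatten()[-8:]])
assert check_expected_slice(video, reference)             # identical values pass
assert not check_expected_slice(video + 0.01, reference)  # a 0.01 shift exceeds atol=1e-3

Storing only 16 reference values keeps the fixture small while still exercising both ends of the generated tensor, so genuine numerical regressions in the pipeline surface as test failures.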