Commit 503ca81

update

1 parent 3c8b67b

3 files changed: +25 −18 lines changed

tests/pipelines/wan/test_wan.py (9 additions, 8 deletions)

@@ -15,7 +15,6 @@
 import gc
 import unittest

-import numpy as np
 import torch
 from transformers import AutoTokenizer, T5EncoderModel

@@ -29,9 +28,7 @@
 )

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import (
-    PipelineTesterMixin,
-)
+from ..test_pipelines_common import PipelineTesterMixin


 enable_full_determinism()
@@ -127,11 +124,15 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (9, 3, 16, 16))
-        expected_video = torch.randn(9, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4525, 0.452, 0.4485, 0.4534, 0.4524, 0.4529, 0.454, 0.453, 0.5127, 0.5326, 0.5204, 0.5253, 0.5439, 0.5424, 0.5133, 0.5078])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):

tests/pipelines/wan/test_wan_image_to_video.py (8 additions, 5 deletions)

@@ -14,7 +14,6 @@

 import unittest

-import numpy as np
 import torch
 from PIL import Image
 from transformers import (
@@ -147,11 +146,15 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (9, 3, 16, 16))
-        expected_video = torch.randn(9, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4525, 0.4525, 0.4497, 0.4536, 0.452, 0.4529, 0.454, 0.4535, 0.5072, 0.5527, 0.5165, 0.5244, 0.5481, 0.5282, 0.5208, 0.5214])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):

tests/pipelines/wan/test_wan_video_to_video.py (8 additions, 5 deletions)

@@ -14,7 +14,6 @@

 import unittest

-import numpy as np
 import torch
 from PIL import Image
 from transformers import AutoTokenizer, T5EncoderModel
@@ -123,11 +122,15 @@ def test_inference(self):
         inputs = self.get_dummy_inputs(device)
         video = pipe(**inputs).frames
         generated_video = video[0]
-
         self.assertEqual(generated_video.shape, (17, 3, 16, 16))
-        expected_video = torch.randn(17, 3, 16, 16)
-        max_diff = np.abs(generated_video - expected_video).max()
-        self.assertLessEqual(max_diff, 1e10)
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4522, 0.4534, 0.4532, 0.4553, 0.4526, 0.4538, 0.4533, 0.4547, 0.513, 0.5176, 0.5286, 0.4958, 0.4955, 0.5381, 0.5154, 0.5195])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

     @unittest.skip("Test not supported")
     def test_attention_slicing_forward_pass(self):
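All three tests replace the old placeholder assertion, which compared the output against torch.randn noise with an atol of 1e10 and so could never fail meaningfully, with a deterministic check on the first and last eight values of the flattened output. Below is a minimal sketch of that slice comparison in isolation; the helper name and the stand-in tensor are illustrative and not part of the commit.

import torch


def check_expected_slice(generated_video: torch.Tensor, expected_slice: torch.Tensor, atol: float = 1e-3) -> bool:
    # Flatten the (frames, channels, height, width) output and keep only the
    # first and last eight values; these sixteen numbers act as a fingerprint
    # of the full tensor.
    flat = generated_video.flatten()
    generated_slice = torch.cat([flat[:8], flat[-8:]])
    return torch.allclose(generated_slice, expected_slice, atol=atol)


# Illustrative usage with a deterministic stand-in for a real pipeline run.
video = torch.linspace(0.0, 1.0, steps=9 * 3 * 16 * 16).reshape(9, 3, 16, 16)
flat = video.flatten()
assert check_expected_slice(video, torch.cat([flat[:8], flat[-8:]]))

Storing only sixteen reference values keeps the expected output readable inside the test file while still catching numerical drift in the pipeline, which is why the same pattern is applied uniformly to the text-to-video, image-to-video, and video-to-video tests.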
