
Commit c18b63c

fix
1 parent ff27aea commit c18b63c

File tree: 3 files changed (+49, -5 lines)


tests/pipelines/wan/test_wan.py

Lines changed: 0 additions & 1 deletion
@@ -132,7 +132,6 @@ def test_inference(self):
 
         generated_slice = generated_video.flatten()
         generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
-        print("txt2video:", [round(x, 4) for x in generated_slice.tolist()])
         self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
 
     @unittest.skip("Test not supported")
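
The print calls removed in this commit look like temporary debug output for capturing fresh reference values. A minimal sketch of that capture pattern, with a random tensor standing in for the pipeline output (the real tests get it from pipe(**inputs).frames[0] built from their dummy components), might look like:

import torch

# Stand-in for `generated_video` with the same (frames, channels, height, width)
# shape the fast tests assert on; in the real test it comes from the pipeline.
generated_video = torch.rand(9, 3, 16, 16)

# Same slice the tests compare: the first 8 and last 8 flattened values.
generated_slice = generated_video.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])

# Rounded values in a form that can be pasted into `expected_slice`.
print("txt2video:", [round(x, 4) for x in generated_slice.tolist()])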

tests/pipelines/wan/test_wan_image_to_video.py

Lines changed: 49 additions & 3 deletions
@@ -149,12 +149,11 @@ def test_inference(self):
         self.assertEqual(generated_video.shape, (9, 3, 16, 16))
 
         # fmt: off
-        expected_slice = torch.tensor([0.4531, 0.4527, 0.4498, 0.4542, 0.4526, 0.4527, 0.4534, 0.4534, 0.5061, 0.5185, 0.5283, 0.5181, 0.5309, 0.5365, 0.5113, 0.5244])
+        expected_slice = torch.tensor([0.4525, 0.4525, 0.4497, 0.4536, 0.452, 0.4529, 0.454, 0.4535, 0.5072, 0.5527, 0.5165, 0.5244, 0.5481, 0.5282, 0.5208, 0.5214])
         # fmt: on
 
         generated_slice = generated_video.flatten()
         generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
-        print("image2video:", [round(x, 4) for x in generated_slice.tolist()])
         self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
 
     @unittest.skip("Test not supported")
@@ -166,7 +165,25 @@ def test_inference_batch_single_identical(self):
         pass
 
 
-class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
+class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = WanImageToVideoPipeline
+    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
+    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+    required_optional_params = frozenset(
+        [
+            "num_inference_steps",
+            "generator",
+            "latents",
+            "return_dict",
+            "callback_on_step_end",
+            "callback_on_step_end_tensor_inputs",
+        ]
+    )
+    test_xformers_attention = False
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         vae = AutoencoderKLWan(
@@ -251,3 +268,32 @@ def get_dummy_inputs(self, device, seed=0):
             "output_type": "pt",
         }
         return inputs
+
+    def test_inference(self):
+        device = "cpu"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        video = pipe(**inputs).frames
+        generated_video = video[0]
+        self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+
+        # fmt: off
+        expected_slice = torch.tensor([0.4531, 0.4527, 0.4498, 0.4542, 0.4526, 0.4527, 0.4534, 0.4534, 0.5061, 0.5185, 0.5283, 0.5181, 0.5309, 0.5365, 0.5113, 0.5244])
+        # fmt: on
+
+        generated_slice = generated_video.flatten()
+        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+    @unittest.skip("Test not supported")
+    def test_attention_slicing_forward_pass(self):
+        pass
+
+    @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
+    def test_inference_batch_single_identical(self):
+        pass
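
One detail worth noting about the first hunk above: the refreshed expected_slice for WanImageToVideoPipelineFastTests.test_inference differs from the old values by far more than the atol=1e-3 tolerance, which is presumably why the reference values were updated rather than the tolerance loosened. A quick check, with the values copied from the diff:

import torch

old = torch.tensor([0.4531, 0.4527, 0.4498, 0.4542, 0.4526, 0.4527, 0.4534, 0.4534, 0.5061, 0.5185, 0.5283, 0.5181, 0.5309, 0.5365, 0.5113, 0.5244])
new = torch.tensor([0.4525, 0.4525, 0.4497, 0.4536, 0.4520, 0.4529, 0.4540, 0.4535, 0.5072, 0.5527, 0.5165, 0.5244, 0.5481, 0.5282, 0.5208, 0.5214])

print((old - new).abs().max().item())       # ~0.034, the largest gap between old and new values
print(torch.allclose(old, new, atol=1e-3))  # False: the old references no longer match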

tests/pipelines/wan/test_wan_video_to_video.py

Lines changed: 0 additions & 1 deletion
@@ -130,7 +130,6 @@ def test_inference(self):
 
         generated_slice = generated_video.flatten()
         generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
-        print("video2video:", [round(x, 4) for x in generated_slice.tolist()])
         self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
 
     @unittest.skip("Test not supported")
