Skip to content

Commit 087c835

Browse files
authored
unittest for wan parallel (#78)
1 parent f928511 commit 087c835

File tree

2 files changed

+55
-2
lines changed

2 files changed

+55
-2
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import torch
2+
import torch.multiprocessing as mp
3+
import unittest
4+
import numpy as np
5+
6+
from diffsynth_engine.utils.loader import load_file
7+
from diffsynth_engine.utils.parallel import ParallelModel
8+
from diffsynth_engine.models.wan.wan_vae import WanVideoVAE
9+
from diffsynth_engine import fetch_model
10+
from tests.common.test_case import VideoTestCase
11+
12+
13+
class TestWanVAEParallel(VideoTestCase):
14+
@classmethod
15+
def setUpClass(cls):
16+
mp.set_start_method("spawn")
17+
cls._vae_model_path = fetch_model("muse/wan2.1-vae", path="vae.safetensors")
18+
loaded_state_dict = load_file(cls._vae_model_path)
19+
vae = WanVideoVAE.from_state_dict(loaded_state_dict, parallelism=4)
20+
cls.vae = ParallelModel(vae, cfg_degree=1, sp_ulysses_degree=4, sp_ring_degree=1, tp_degree=1)
21+
cls._input_video = cls.get_input_video("astronaut_320_320.mp4")
22+
23+
@classmethod
24+
def tearDownClass(cls):
25+
del cls.vae
26+
27+
def test_encode_parallel(self):
28+
expected_tensor = self.get_expect_tensor("wan/wan_vae.safetensors")
29+
expected = expected_tensor["encoded"]
30+
video_frames = [
31+
torch.tensor(np.array(frame, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
32+
for frame in self._input_video.frames
33+
]
34+
video_tensor = torch.stack(video_frames, dim=2)
35+
with torch.no_grad():
36+
result = self.vae.encode(video_tensor, device="cuda", tiled=True).cpu()
37+
self.assertTensorEqual(result, expected)
38+
39+
def test_decode_parallel(self):
40+
expected_tensor = self.get_expect_tensor("wan/wan_vae.safetensors")
41+
latent_tensor, expected = expected_tensor["encoded"], expected_tensor["decoded"]
42+
with torch.no_grad():
43+
result = self.vae.decode(latent_tensor, device="cuda", tiled=True)[0].cpu()
44+
self.assertTensorEqual(result, expected)
45+
46+
47+
if __name__ == "__main__":
48+
unittest.main()

tests/test_pipelines/test_wan_video_tp.py renamed to tests/test_pipelines/test_wan_video_parallel.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import torch.multiprocessing as mp
12
import unittest
23

34
from tests.common.test_case import VideoTestCase
@@ -8,13 +9,18 @@
89
class TestWanVideoTP(VideoTestCase):
910
@classmethod
1011
def setUpClass(cls):
12+
mp.set_start_method("spawn")
1113
config = WanModelConfig(
1214
model_path=fetch_model("MusePublic/wan2.1-1.3b", path="dit.safetensors"),
1315
t5_path=fetch_model("muse/wan2.1-umt5", path="umt5.safetensors"),
1416
vae_path=fetch_model("muse/wan2.1-vae", path="vae.safetensors"),
1517
)
1618
cls.pipe = WanVideoPipeline.from_pretrained(config, parallelism=4, use_cfg_parallel=True)
1719

20+
@classmethod
21+
def tearDownClass(cls):
22+
del cls.pipe
23+
1824
def test_txt2video(self):
1925
video = self.pipe(
2026
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
@@ -23,8 +29,7 @@ def test_txt2video(self):
2329
width=480,
2430
height=480,
2531
)
26-
self.save_video(video, "wan_tp_t2v.mp4")
27-
del self.pipe
32+
self.save_video(video, "wan_t2v.mp4")
2833

2934

3035
if __name__ == "__main__":

0 commit comments

Comments
 (0)