Skip to content

Commit c8dc8c4

Browse files
authored
[feat] support pab and multi-gpu in open-sora-plan v1.2.0 (#223)
* update parallel * update pab * update test * empty cache * update tests
1 parent 868c489 commit c8dc8c4

File tree

13 files changed

+246
-300
lines changed

13 files changed

+246
-300
lines changed

eval/pab/experiments/opensora_plan.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from utils import generate_func, read_prompt_list
22

3-
from videosys import OpenSoraPlanConfig, OpenSoraPlanPABConfig, VideoSysEngine
3+
from videosys import OpenSoraPlanConfig, OpenSoraPlanV110PABConfig, VideoSysEngine
44

55

66
def eval_base(prompt_list):
@@ -10,7 +10,7 @@ def eval_base(prompt_list):
1010

1111

1212
def eval_pab1(prompt_list):
13-
pab_config = OpenSoraPlanPABConfig(
13+
pab_config = OpenSoraPlanV110PABConfig(
1414
spatial_gap=2,
1515
temporal_gap=4,
1616
cross_gap=6,
@@ -21,7 +21,7 @@ def eval_pab1(prompt_list):
2121

2222

2323
def eval_pab2(prompt_list):
24-
pab_config = OpenSoraPlanPABConfig(
24+
pab_config = OpenSoraPlanV110PABConfig(
2525
spatial_gap=3,
2626
temporal_gap=5,
2727
cross_gap=7,
@@ -32,7 +32,7 @@ def eval_pab2(prompt_list):
3232

3333

3434
def eval_pab3(prompt_list):
35-
pab_config = OpenSoraPlanPABConfig(
35+
pab_config = OpenSoraPlanV110PABConfig(
3636
spatial_gap=5,
3737
temporal_gap=7,
3838
cross_gap=9,

examples/open_sora_plan/sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ def run_base():
55
# open-sora-plan v1.2.0
66
# transformer_type (len, res): 93x480p 93x720p 29x480p 29x720p
77
# change num_gpus for multi-gpu inference
8-
config = OpenSoraPlanConfig(version="v120", transformer_type="93x480p", num_gpus=1)
8+
config = OpenSoraPlanConfig(version="v120", transformer_type="29x480p", num_gpus=1)
99
engine = VideoSysEngine(config)
1010

1111
prompt = "Sunset over the sea."
Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,40 @@
11
import pytest
22

33
from videosys import CogVideoXConfig, VideoSysEngine
4+
from videosys.utils.test import empty_cache
45

56

67
@pytest.mark.parametrize("num_gpus", [1, 2])
7-
def test_base(num_gpus):
8-
config = CogVideoXConfig(num_gpus=num_gpus)
8+
@pytest.mark.parametrize("model_path", ["THUDM/CogVideoX-2b", "THUDM/CogVideoX-5b"])
9+
@empty_cache
10+
def test_base(num_gpus, model_path):
11+
config = CogVideoXConfig(model_path=model_path, num_gpus=num_gpus)
912
engine = VideoSysEngine(config)
1013

1114
prompt = "Sunset over the sea."
1215
video = engine.generate(prompt, seed=0).video[0]
13-
engine.save_video(video, f"./test_outputs/{prompt}_cogvideo_{num_gpus}.mp4")
16+
engine.save_video(video, f"./test_outputs/{prompt}_cogvideo_{model_path.replace('/', '_')}_base_{num_gpus}.mp4")
1417

1518

1619
@pytest.mark.parametrize("num_gpus", [1])
17-
def test_pab(num_gpus):
18-
config = CogVideoXConfig(num_gpus=num_gpus, enable_pab=True)
20+
@pytest.mark.parametrize("model_path", ["THUDM/CogVideoX-2b"])
21+
@empty_cache
22+
def test_pab(num_gpus, model_path):
23+
config = CogVideoXConfig(model_path=model_path, num_gpus=num_gpus, enable_pab=True)
1924
engine = VideoSysEngine(config)
2025

2126
prompt = "Sunset over the sea."
2227
video = engine.generate(prompt, seed=0).video[0]
23-
engine.save_video(video, f"./test_outputs/{prompt}_cogvideo_pab_{num_gpus}.mp4")
28+
engine.save_video(video, f"./test_outputs/{prompt}_cogvideo_{model_path.replace('/', '_')}_pab_{num_gpus}.mp4")
2429

2530

2631
@pytest.mark.parametrize("num_gpus", [1])
27-
def test_low_mem(num_gpus):
28-
config = CogVideoXConfig(num_gpus=num_gpus, cpu_offload=True, vae_tiling=True)
32+
@pytest.mark.parametrize("model_path", ["THUDM/CogVideoX-2b"])
33+
@empty_cache
34+
def test_low_mem(num_gpus, model_path):
35+
config = CogVideoXConfig(model_path=model_path, num_gpus=num_gpus, cpu_offload=True, vae_tiling=True)
2936
engine = VideoSysEngine(config)
3037

3138
prompt = "Sunset over the sea."
3239
video = engine.generate(prompt, seed=0).video[0]
33-
engine.save_video(video, f"./test_outputs/{prompt}_cogvideo_low_mem_{num_gpus}.mp4")
40+
engine.save_video(video, f"./test_outputs/{prompt}_cogvideo_{model_path.replace('/', '_')}_low_mem_{num_gpus}.mp4")

tests/pipelines/latte/test_latte.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import pytest
22

33
from videosys import LatteConfig, VideoSysEngine
4+
from videosys.utils.test import empty_cache
45

56

67
@pytest.mark.parametrize("num_gpus", [1, 2])
8+
@empty_cache
79
def test_base(num_gpus):
810
config = LatteConfig(num_gpus=num_gpus)
911
engine = VideoSysEngine(config)
@@ -14,6 +16,7 @@ def test_base(num_gpus):
1416

1517

1618
@pytest.mark.parametrize("num_gpus", [1])
19+
@empty_cache
1720
def test_pab(num_gpus):
1821
config = LatteConfig(num_gpus=num_gpus, enable_pab=True)
1922
engine = VideoSysEngine(config)
@@ -24,6 +27,7 @@ def test_pab(num_gpus):
2427

2528

2629
@pytest.mark.parametrize("num_gpus", [1])
30+
@empty_cache
2731
def test_low_mem(num_gpus):
2832
config = LatteConfig(num_gpus=num_gpus, cpu_offload=True)
2933
engine = VideoSysEngine(config)

tests/pipelines/open_sora/test_open_sora.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import pytest
22

33
from videosys import OpenSoraConfig, VideoSysEngine
4+
from videosys.utils.test import empty_cache
45

56

67
@pytest.mark.parametrize("num_gpus", [1, 2])
8+
@empty_cache
79
def test_base(num_gpus):
810
config = OpenSoraConfig(num_gpus=num_gpus)
911
engine = VideoSysEngine(config)
@@ -14,6 +16,7 @@ def test_base(num_gpus):
1416

1517

1618
@pytest.mark.parametrize("num_gpus", [1])
19+
@empty_cache
1720
def test_pab(num_gpus):
1821
config = OpenSoraConfig(num_gpus=num_gpus, enable_pab=True)
1922
engine = VideoSysEngine(config)
@@ -24,6 +27,7 @@ def test_pab(num_gpus):
2427

2528

2629
@pytest.mark.parametrize("num_gpus", [1])
30+
@empty_cache
2731
def test_low_mem(num_gpus):
2832
config = OpenSoraConfig(num_gpus=num_gpus, cpu_offload=True, tiling_size=1)
2933
engine = VideoSysEngine(config)
Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,42 @@
11
import pytest
22

33
from videosys import OpenSoraPlanConfig, VideoSysEngine
4+
from videosys.utils.test import empty_cache
45

56

67
@pytest.mark.parametrize("num_gpus", [1, 2])
7-
def test_base(num_gpus):
8-
config = OpenSoraPlanConfig(num_gpus=num_gpus)
8+
@pytest.mark.parametrize("model", [("v120", "29x480p")])
9+
@empty_cache
10+
def test_base(num_gpus, model):
11+
config = OpenSoraPlanConfig(version=model[0], transformer_type=model[1], num_gpus=num_gpus)
912
engine = VideoSysEngine(config)
1013

1114
prompt = "Sunset over the sea."
1215
video = engine.generate(prompt, seed=0).video[0]
13-
engine.save_video(video, f"./test_outputs/{prompt}_open_sora_plan_{num_gpus}.mp4")
16+
engine.save_video(video, f"./test_outputs/{prompt}_open_sora_plan_{model[0]}_{model[1]}_{num_gpus}.mp4")
1417

1518

1619
@pytest.mark.parametrize("num_gpus", [1])
17-
def test_pab(num_gpus):
18-
config = OpenSoraPlanConfig(num_gpus=num_gpus, enable_pab=True)
20+
@pytest.mark.parametrize("model", [("v120", "29x480p")])
21+
@empty_cache
22+
def test_pab(num_gpus, model):
23+
config = OpenSoraPlanConfig(version=model[0], transformer_type=model[1], num_gpus=num_gpus, enable_pab=True)
1924
engine = VideoSysEngine(config)
2025

2126
prompt = "Sunset over the sea."
2227
video = engine.generate(prompt, seed=0).video[0]
23-
engine.save_video(video, f"./test_outputs/{prompt}_open_sora_plan_pab_{num_gpus}.mp4")
28+
engine.save_video(video, f"./test_outputs/{prompt}_open_sora_plan_{model[0]}_{model[1]}_pab_{num_gpus}.mp4")
2429

2530

2631
@pytest.mark.parametrize("num_gpus", [1])
27-
def test_low_mem(num_gpus):
28-
config = OpenSoraPlanConfig(num_gpus=num_gpus, cpu_offload=True, enable_tiling=True)
32+
@pytest.mark.parametrize("model", [("v120", "29x480p")])
33+
@empty_cache
34+
def test_low_mem(num_gpus, model):
35+
config = OpenSoraPlanConfig(
36+
version=model[0], transformer_type=model[1], num_gpus=num_gpus, cpu_offload=True, enable_tiling=True
37+
)
2938
engine = VideoSysEngine(config)
3039

3140
prompt = "Sunset over the sea."
3241
video = engine.generate(prompt, seed=0).video[0]
33-
engine.save_video(video, f"./test_outputs/{prompt}_open_sora_plan_low_mem_{num_gpus}.mp4")
42+
engine.save_video(video, f"./test_outputs/{prompt}_open_sora_plan_{model[0]}_{model[1]}_low_mem_{num_gpus}.mp4")
Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,44 @@
11
import pytest
22

33
from videosys import VchitectConfig, VideoSysEngine
4+
from videosys.utils.test import empty_cache
45

56

67
@pytest.mark.parametrize("num_gpus", [1, 2])
7-
def test_base(num_gpus):
8-
config = VchitectConfig(num_gpus=num_gpus)
8+
@pytest.mark.parametrize("model_path", ["Vchitect/Vchitect-2.0-2B"])
9+
@empty_cache
10+
def test_base(num_gpus, model_path):
11+
config = VchitectConfig(model_path=model_path, num_gpus=num_gpus)
912
engine = VideoSysEngine(config)
1013

1114
prompt = "Sunset over the sea."
1215
video = engine.generate(prompt, seed=0).video[0]
13-
engine.save_video(video, f"./test_outputs/{prompt}_vchitect_{num_gpus}.mp4")
16+
engine.save_video(video, f"./test_outputs/{prompt}_vchitect_{model_path.replace('/', '_')}_base_{num_gpus}.mp4")
1417

1518

1619
@pytest.mark.parametrize("num_gpus", [1])
17-
def test_pab(num_gpus):
18-
config = VchitectConfig(num_gpus=num_gpus, enable_pab=True)
20+
@pytest.mark.parametrize("model_path", ["Vchitect/Vchitect-2.0-2B"])
21+
@empty_cache
22+
def test_pab(num_gpus, model_path):
23+
config = VchitectConfig(model_path=model_path, num_gpus=num_gpus, enable_pab=True)
1924
engine = VideoSysEngine(config)
2025

2126
prompt = "Sunset over the sea."
2227
video = engine.generate(prompt, seed=0).video[0]
23-
engine.save_video(video, f"./test_outputs/{prompt}_vchitect_pab_{num_gpus}.mp4")
28+
engine.save_video(video, f"./test_outputs/{prompt}_vchitect_{model_path.replace('/', '_')}_pab_{num_gpus}.mp4")
2429

2530

2631
@pytest.mark.parametrize("num_gpus", [1])
27-
def test_low_mem(num_gpus):
28-
config = VchitectConfig(num_gpus=num_gpus, cpu_offload=True)
32+
@pytest.mark.parametrize("model_path", ["Vchitect/Vchitect-2.0-2B"])
33+
@empty_cache
34+
def test_low_mem(num_gpus, model_path):
35+
config = VchitectConfig(
36+
model_path=model_path,
37+
num_gpus=num_gpus,
38+
cpu_offload=True,
39+
)
2940
engine = VideoSysEngine(config)
3041

3142
prompt = "Sunset over the sea."
3243
video = engine.generate(prompt, seed=0).video[0]
33-
engine.save_video(video, f"./test_outputs/{prompt}_vchitect_low_mem_{num_gpus}.mp4")
44+
engine.save_video(video, f"./test_outputs/{prompt}_vchitect_{model_path.replace('/', '_')}_low_mem_{num_gpus}.mp4")

videosys/__init__.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,20 @@
33
from .pipelines.cogvideox import CogVideoXConfig, CogVideoXPABConfig, CogVideoXPipeline
44
from .pipelines.latte import LatteConfig, LattePABConfig, LattePipeline
55
from .pipelines.open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline
6-
from .pipelines.open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
7-
from .pipelines.vchitect import VchitectConfig, VchitectXLPipeline
6+
from .pipelines.open_sora_plan import (
7+
OpenSoraPlanConfig,
8+
OpenSoraPlanPipeline,
9+
OpenSoraPlanV110PABConfig,
10+
OpenSoraPlanV120PABConfig,
11+
)
12+
from .pipelines.vchitect import VchitectConfig, VchitectPABConfig, VchitectXLPipeline
813

914
__all__ = [
1015
"initialize",
1116
"VideoSysEngine",
1217
"LattePipeline", "LatteConfig", "LattePABConfig",
13-
"OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanPABConfig",
18+
"OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanV110PABConfig", "OpenSoraPlanV120PABConfig",
1419
"OpenSoraPipeline", "OpenSoraConfig", "OpenSoraPABConfig",
1520
"CogVideoXPipeline", "CogVideoXConfig", "CogVideoXPABConfig",
16-
"VchitectXLPipeline", "VchitectConfig",
21+
"VchitectXLPipeline", "VchitectConfig", "VchitectPABConfig"
1722
] # fmt: skip

0 commit comments

Comments
 (0)