
Commit cd7ac04

support vchitect and fix some bugs (#216)
- support vchitect-2.0 including inference, multi-gpu, pab, low memory
- improve cpu offload: add more modules to offload to save more memory
- adjust vae forward in open sora and open sora plan to suit cpu offload
- adjust offload step to be more efficient
- fix pab step problem: sampling steps for Open Sora #211
1 parent a2b71b6 commit cd7ac04

25 files changed (+2525 / −85 lines)
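
For context, a minimal usage sketch of the new Vchitect support added here (assuming VchitectConfig accepts num_gpus, enable_pab and cpu_offload together; the canonical per-feature examples are in examples/vchitect/sample.py below):

from videosys import VchitectConfig, VideoSysEngine

# Sketch only: mirrors examples/vchitect/sample.py but enables PAB and CPU offload
# in one config; combining the two flags is an assumption, not verified by this commit.
config = VchitectConfig("Vchitect/Vchitect-2.0-2B", num_gpus=1, enable_pab=True, cpu_offload=True)
engine = VideoSysEngine(config)

prompt = "Sunset over the sea."
video = engine.generate(prompt, seed=0).video[0]
engine.save_video(video, f"./outputs/{prompt}.mp4")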

README.md

Lines changed: 18 additions & 10 deletions
@@ -50,7 +50,7 @@ pip install -e .
 
 VideoSys supports many diffusion models with our various acceleration techniques, enabling these models to run faster and consume less memory.
 
-<b>You can find all available models and their supported acceleration techniques in the following table. Click `Doc` to see how to use them.</b>
+<b>You can find all available models and their supported acceleration techniques in the following table. Click `Code` to see how to use them.</b>
 
 <table>
   <tr>
@@ -65,20 +65,20 @@ VideoSys supports many diffusion models with our various acceleration techniques
     <th><a href="https://github.com/NUS-HPC-AI-Lab/VideoSys?tab=readme-ov-file#pyramid-attention-broadcast-pab-blogdoc">PAB</a></th>
   </tr>
   <tr>
-    <td>Open-Sora [<a href="https://github.com/hpcaitech/Open-Sora">source</a>]</td>
-    <td align="center">🟡</td>
+    <td>Vchitect [<a href="https://github.com/Vchitect/Vchitect-2.0">source</a>]</td>
+    <td align="center">/</td>
     <td align="center">✅</td>
     <td align="center">✅</td>
     <td align="center">✅</td>
-    <td align="center"><a href="./examples/open_sora/sample.py">Code</a></td>
+    <td align="center"><a href="./examples/vchitect/sample.py">Code</a></td>
   </tr>
   <tr>
-    <td>Open-Sora-Plan [<a href="https://github.com/PKU-YuanGroup/Open-Sora-Plan">source</a>]</td>
+    <td>CogVideoX [<a href="https://github.com/THUDM/CogVideo">source</a>]</td>
     <td align="center">/</td>
     <td align="center">✅</td>
+    <td align="center">/</td>
     <td align="center">✅</td>
-    <td align="center">✅</td>
-    <td align="center"><a href="./examples/open_sora_plan/sample.py">Code</a></td>
+    <td align="center"><a href="./examples/cogvideox/sample.py">Code</a></td>
   </tr>
   <tr>
     <td>Latte [<a href="https://github.com/Vchitect/Latte">source</a>]</td>
@@ -89,12 +89,20 @@ VideoSys supports many diffusion models with our various acceleration techniques
     <td align="center"><a href="./examples/latte/sample.py">Code</a></td>
   </tr>
   <tr>
-    <td>CogVideoX [<a href="https://github.com/THUDM/CogVideo">source</a>]</td>
+    <td>Open-Sora-Plan [<a href="https://github.com/PKU-YuanGroup/Open-Sora-Plan">source</a>]</td>
     <td align="center">/</td>
     <td align="center">✅</td>
-    <td align="center">/</td>
     <td align="center">✅</td>
-    <td align="center"><a href="./examples/cogvideox/sample.py">Code</a></td>
+    <td align="center">✅</td>
+    <td align="center"><a href="./examples/open_sora_plan/sample.py">Code</a></td>
+  </tr>
+  <tr>
+    <td>Open-Sora [<a href="https://github.com/hpcaitech/Open-Sora">source</a>]</td>
+    <td align="center">🟡</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+    <td align="center">✅</td>
+    <td align="center"><a href="./examples/open_sora/sample.py">Code</a></td>
   </tr>
 </table>

examples/vchitect/sample.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+from videosys import VchitectConfig, VideoSysEngine
+
+
+def run_base():
+    # change num_gpus for multi-gpu inference
+    config = VchitectConfig("Vchitect/Vchitect-2.0-2B", num_gpus=1)
+    engine = VideoSysEngine(config)
+
+    prompt = "Sunset over the sea."
+    # seed=-1 means random seed. >0 means fixed seed.
+    # WxH: 480x288 624x352 432x240 768x432
+    video = engine.generate(
+        prompt=prompt,
+        negative_prompt="",
+        num_inference_steps=100,
+        guidance_scale=7.5,
+        width=480,
+        height=288,
+        frames=40,
+        seed=0,
+    ).video[0]
+    engine.save_video(video, f"./outputs/{prompt}.mp4")
+
+
+def run_pab():
+    config = VchitectConfig("Vchitect/Vchitect-2.0-2B", enable_pab=True)
+    engine = VideoSysEngine(config)
+
+    prompt = "Sunset over the sea."
+    video = engine.generate(prompt).video[0]
+    engine.save_video(video, f"./outputs/{prompt}.mp4")
+
+
+def run_low_mem():
+    config = VchitectConfig("Vchitect/Vchitect-2.0-2B", cpu_offload=True)
+    engine = VideoSysEngine(config)
+
+    prompt = "Sunset over the sea."
+    video = engine.generate(prompt).video[0]
+    engine.save_video(video, f"./outputs/{prompt}.mp4")
+
+
+if __name__ == "__main__":
+    run_base()
+    # run_pab()
+    # run_low_mem()

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+accelerate>0.17.0
 bs4
 click
 colossalai

tests/examples/test_sample.py

Lines changed: 2 additions & 1 deletion
@@ -11,8 +11,9 @@
 import examples.latte.sample as latte
 import examples.open_sora.sample as open_sora
 import examples.open_sora_plan.sample as open_sora_plan
+import examples.vchitect.sample as vchitect
 
-files = [cogvideox, latte, open_sora, open_sora_plan]
+files = [cogvideox, latte, open_sora, open_sora_plan, vchitect]
 members = []
 
 for file in files:

tests/pipelines/vchitect/__init__.py

Whitespace-only changes.
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import pytest
+
+from videosys import VchitectConfig, VideoSysEngine
+
+
+@pytest.mark.parametrize("num_gpus", [1, 2])
+def test_base(num_gpus):
+    config = VchitectConfig(num_gpus=num_gpus)
+    engine = VideoSysEngine(config)
+
+    prompt = "Sunset over the sea."
+    video = engine.generate(prompt, seed=0).video[0]
+    engine.save_video(video, f"./test_outputs/{prompt}_vchitect_{num_gpus}.mp4")
+
+
+@pytest.mark.parametrize("num_gpus", [1])
+def test_pab(num_gpus):
+    config = VchitectConfig(num_gpus=num_gpus, enable_pab=True)
+    engine = VideoSysEngine(config)
+
+    prompt = "Sunset over the sea."
+    video = engine.generate(prompt, seed=0).video[0]
+    engine.save_video(video, f"./test_outputs/{prompt}_vchitect_pab_{num_gpus}.mp4")
+
+
+@pytest.mark.parametrize("num_gpus", [1])
+def test_low_mem(num_gpus):
+    config = VchitectConfig(num_gpus=num_gpus, cpu_offload=True)
+    engine = VideoSysEngine(config)
+
+    prompt = "Sunset over the sea."
+    video = engine.generate(prompt, seed=0).video[0]
+    engine.save_video(video, f"./test_outputs/{prompt}_vchitect_low_mem_{num_gpus}.mp4")

videosys/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -4,12 +4,14 @@
 from .pipelines.latte import LatteConfig, LattePABConfig, LattePipeline
 from .pipelines.open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline
 from .pipelines.open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
+from .pipelines.vchitect import VchitectConfig, VchitectXLPipeline
 
 __all__ = [
     "initialize",
     "VideoSysEngine",
     "LattePipeline", "LatteConfig", "LattePABConfig",
     "OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanPABConfig",
     "OpenSoraPipeline", "OpenSoraConfig", "OpenSoraPABConfig",
-    "CogVideoXConfig", "CogVideoXPipeline", "CogVideoXPABConfig"
+    "CogVideoXPipeline", "CogVideoXConfig", "CogVideoXPABConfig",
+    "VchitectXLPipeline", "VchitectConfig",
 ]  # fmt: skip

videosys/core/comm.py

Lines changed: 14 additions & 0 deletions
@@ -404,3 +404,17 @@ def all_to_all_with_pad(
         input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad)
 
     return input_
+
+
+def split_from_second_dim(x, batch_size, parallel_group):
+    x = x.view(batch_size, -1, *x.shape[1:])
+    x = split_sequence(x, parallel_group, dim=1, grad_scale="down", pad=get_pad("temporal"))
+    x = x.reshape(-1, *x.shape[2:])
+    return x
+
+
+def gather_from_second_dim(x, batch_size, parallel_group):
+    x = x.view(batch_size, -1, *x.shape[1:])
+    x = gather_sequence(x, parallel_group, dim=1, grad_scale="up", pad=get_pad("temporal"))
+    x = x.reshape(-1, *x.shape[2:])
+    return x
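
Shape-wise, split_from_second_dim restores the batch dimension, shards the second (temporal/frame) dimension across the sequence-parallel group, and flattens back; gather_from_second_dim reverses it. A standalone, single-process sketch of just that reshape-shard-reshape pattern (torch.chunk stands in for split_sequence; world_size and rank are hypothetical):

import torch

# Illustration of the reshape logic only; the real helper shards via
# split_sequence over parallel_group with padding from get_pad("temporal").
def split_second_dim_local(x: torch.Tensor, batch_size: int, world_size: int, rank: int) -> torch.Tensor:
    x = x.view(batch_size, -1, *x.shape[1:])     # (B*T, ...) -> (B, T, ...)
    x = torch.chunk(x, world_size, dim=1)[rank]  # keep this rank's slice of T
    return x.reshape(-1, *x.shape[2:])           # (B, T/ws, ...) -> (B*T/ws, ...)

x = torch.randn(2 * 8, 16, 64)  # batch=2, frames=8, tokens=16, dim=64
print(split_second_dim_local(x, batch_size=2, world_size=2, rank=0).shape)  # torch.Size([8, 16, 64])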

videosys/core/pab_mgr.py

Lines changed: 5 additions & 6 deletions
@@ -6,7 +6,6 @@
 class PABConfig:
     def __init__(
         self,
-        steps: int,
         cross_broadcast: bool = False,
         cross_threshold: list = None,
         cross_range: int = None,
@@ -20,7 +19,7 @@ def __init__(
         mlp_spatial_broadcast_config: dict = None,
         mlp_temporal_broadcast_config: dict = None,
     ):
-        self.steps = steps
+        self.steps = None
 
         self.cross_broadcast = cross_broadcast
         self.cross_threshold = cross_threshold
@@ -45,7 +44,7 @@ class PABManager:
     def __init__(self, config: PABConfig):
         self.config: PABConfig = config
 
-        init_prompt = f"Init Pyramid Attention Broadcast. steps: {config.steps}."
+        init_prompt = f"Init Pyramid Attention Broadcast."
         init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}."
         init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}."
         init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}."
@@ -78,7 +77,7 @@ def if_broadcast_temporal(self, timestep: int, count: int):
         count = (count + 1) % self.config.steps
         return flag, count
 
-    def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
+    def if_broadcast_spatial(self, timestep: int, count: int):
         if (
             self.config.spatial_broadcast
             and (timestep is not None)
@@ -213,10 +212,10 @@ def if_broadcast_temporal(timestep: int, count: int):
     return PAB_MANAGER.if_broadcast_temporal(timestep, count)
 
 
-def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
+def if_broadcast_spatial(timestep: int, count: int):
     if not enable_pab():
         return False, count
-    return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
+    return PAB_MANAGER.if_broadcast_spatial(timestep, count)
 
 
 def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
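
The spatial broadcast decision now depends only on the timestep and a running counter, with the step count read from the manager's config rather than passed per block. A hypothetical caller-side sketch of the updated API (the block attributes and attn call are illustrative names, not from this commit):

# flag tells the block whether to reuse its cached spatial attention output
broadcast, self.spatial_count = if_broadcast_spatial(timestep, self.spatial_count)
if broadcast:
    attn_output = self.cached_spatial_attn       # reuse previous result
else:
    attn_output = self.attn(hidden_states)       # recompute and refresh the cache
    self.cached_spatial_attn = attn_output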

videosys/core/pipeline.py

Lines changed: 4 additions & 3 deletions
@@ -13,9 +13,10 @@ def __init__(self):
 
     @staticmethod
    def set_eval_and_device(device: torch.device, *modules):
-        for module in modules:
-            module.eval()
-            module.to(device)
+        modules = list(modules)
+        for i in range(len(modules)):
+            modules[i] = modules[i].eval()
+            modules[i] = modules[i].to(device)
 
     @abstractmethod
     def generate(self, *args, **kwargs):
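
Capturing the objects returned by eval() and to(), rather than relying on in-place mutation, presumably anticipates the CPU-offload changes in this commit, where a transferred module may be wrapped or re-materialized; for plain nn.Module instances the two forms behave the same.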
