Skip to content

Commit 2e81cd8

Browse files
Implemented cpu offload (#197)
* Implemented CPU offload for other frameworks
* polish example code
* Added test for low mem settings.
* polish
* format
* fix arg
* update pipeline and test

---------

Co-authored-by: ExtremeViscent <[email protected]>
1 parent 4ad17b5 commit 2e81cd8

File tree

14 files changed

+136
-27
lines changed

14 files changed

+136
-27
lines changed

examples/latte/sample.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,15 @@ def run_base():
1616
engine.save_video(video, f"./outputs/{prompt}.mp4")
1717

1818

19+
def run_low_mem():
    """Generate the sample video with Latte-1 using ``cpu_offload=True``.

    Builds the engine from a ``LatteConfig`` with CPU offload switched on,
    generates one video for the fixed prompt and saves it under ``./outputs``.
    """
    prompt = "Sunset over the sea."
    engine = VideoSysEngine(LatteConfig("maxin-cn/Latte-1", cpu_offload=True))

    # Only the first entry of the returned ``video`` sequence is saved.
    first_video = engine.generate(prompt).video[0]
    engine.save_video(first_video, f"./outputs/{prompt}.mp4")
26+
27+
1928
def run_pab():
2029
config = LatteConfig("maxin-cn/Latte-1", enable_pab=True)
2130
engine = VideoSysEngine(config)
@@ -27,4 +36,5 @@ def run_pab():
2736

2837
if __name__ == "__main__":
2938
run_base()
39+
# run_low_mem()
3040
# run_pab()

examples/open_sora/sample.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ def run_base():
2020
engine.save_video(video, f"./outputs/{prompt}.mp4")
2121

2222

23+
def run_low_mem():
    """Generate the sample video with ``cpu_offload=True`` and ``tiling_size=1``.

    Builds the engine from an ``OpenSoraConfig`` with those two options,
    generates one video for the fixed prompt and saves it under ``./outputs``.
    """
    prompt = "Sunset over the sea."
    engine = VideoSysEngine(OpenSoraConfig(cpu_offload=True, tiling_size=1))

    # Only the first entry of the returned ``video`` sequence is saved.
    first_video = engine.generate(prompt).video[0]
    engine.save_video(first_video, f"./outputs/{prompt}.mp4")
30+
31+
2332
def run_pab():
2433
config = OpenSoraConfig(enable_pab=True)
2534
engine = VideoSysEngine(config)
@@ -31,4 +40,5 @@ def run_pab():
3140

3241
if __name__ == "__main__":
3342
run_base()
43+
# run_low_mem()
3444
# run_pab()

examples/open_sora_plan/sample.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,17 @@ def run_base():
1616
engine.save_video(video, f"./outputs/{prompt}.mp4")
1717

1818

19+
def run_low_mem():
    """Generate the sample video with ``cpu_offload=True`` and ``enable_tiling=True``.

    Builds the engine from an ``OpenSoraPlanConfig`` with those two options,
    generates one video for the fixed prompt and saves it under ``./outputs``.
    """
    prompt = "Sunset over the sea."
    engine = VideoSysEngine(OpenSoraPlanConfig(cpu_offload=True, enable_tiling=True))

    # Only the first entry of the returned ``video`` sequence is saved.
    first_video = engine.generate(prompt).video[0]
    engine.save_video(first_video, f"./outputs/{prompt}.mp4")
26+
27+
1928
def run_pab():
20-
config = OpenSoraPlanConfig(num_gpus=1, enable_pab=True)
29+
config = OpenSoraPlanConfig(enable_pab=True)
2130
engine = VideoSysEngine(config)
2231

2332
prompt = "Sunset over the sea."
@@ -27,4 +36,5 @@ def run_pab():
2736

2837
if __name__ == "__main__":
2938
run_base()
39+
# run_low_mem()
3040
# run_pab()

tests/examples/test_sample.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,19 @@
1212
import examples.open_sora.sample as open_sora
1313
import examples.open_sora_plan.sample as open_sora_plan
1414

15+
files = [cogvideox, latte, open_sora, open_sora_plan]
16+
members = []
1517

16-
@pytest.mark.parametrize("file", [cogvideox, latte, open_sora, open_sora_plan])
17-
def test_examples(file):
18-
funcs = inspect.getmembers(file, inspect.isfunction)
19-
for name, func in funcs:
20-
try:
21-
func()
22-
except Exception as e:
23-
raise Exception(f"Failed to run {name} in {file.__file__}") from e
18+
# Flatten the (name, function) pairs of every example module into one list so
# that each example function becomes its own parametrized test case.
for file in files:
    members.extend(inspect.getmembers(file, inspect.isfunction))
print(members)
22+
23+
24+
@pytest.mark.parametrize("members", members)
def test_examples(members):
    """Run a single example function and report where it lives on failure.

    Args:
        members: A ``(name, func)`` pair as produced by ``inspect.getmembers``
            for one function of an example module.

    Raises:
        Exception: If the example function raises; the original error is
            chained as the cause.
    """
    name, func = members
    try:
        func()
    except Exception as e:
        # Take the path from the failing function's own code object. A
        # module-level loop variable would only ever name the last module
        # scanned during collection, mislabelling every other failure.
        source_file = func.__code__.co_filename
        raise Exception(f"Failed to run {name} in {source_file}") from e

tests/pipelines/cogvideox/test_cogvideox.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,13 @@ def test_pab(num_gpus):
2121
prompt = "Sunset over the sea."
2222
video = engine.generate(prompt).video[0]
2323
engine.save_video(video, f"./test_outputs/{prompt}_cogvideo_pab_{num_gpus}.mp4")
24+
25+
26+
@pytest.mark.parametrize("num_gpus", [1])
def test_low_mem(num_gpus):
    """Run CogVideoX generation with ``cpu_offload=True`` and ``vae_tiling=True``.

    The test passes when generation and saving complete without raising.
    """
    prompt = "Sunset over the sea."
    low_mem_config = CogVideoXConfig(num_gpus=num_gpus, cpu_offload=True, vae_tiling=True)
    engine = VideoSysEngine(low_mem_config)

    # Only the first entry of the returned ``video`` sequence is saved.
    first_video = engine.generate(prompt).video[0]
    engine.save_video(first_video, f"./test_outputs/{prompt}_cogvideo_low_mem_{num_gpus}.mp4")

tests/pipelines/latte/test_latte.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,13 @@ def test_pab(num_gpus):
2121
prompt = "Sunset over the sea."
2222
video = engine.generate(prompt).video[0]
2323
engine.save_video(video, f"./test_outputs/{prompt}_latte_pab_{num_gpus}.mp4")
24+
25+
26+
@pytest.mark.parametrize("num_gpus", [1])
def test_low_mem(num_gpus):
    """Run Latte generation with ``cpu_offload=True``.

    The test passes when generation and saving complete without raising.
    """
    prompt = "Sunset over the sea."
    low_mem_config = LatteConfig(num_gpus=num_gpus, cpu_offload=True)
    engine = VideoSysEngine(low_mem_config)

    # Only the first entry of the returned ``video`` sequence is saved.
    first_video = engine.generate(prompt).video[0]
    engine.save_video(first_video, f"./test_outputs/{prompt}_latte_low_mem_{num_gpus}.mp4")

tests/pipelines/open_sora/test_open_sora.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,13 @@ def test_pab(num_gpus):
2121
prompt = "Sunset over the sea."
2222
video = engine.generate(prompt).video[0]
2323
engine.save_video(video, f"./test_outputs/{prompt}_open_sora_pab_{num_gpus}.mp4")
24+
25+
26+
@pytest.mark.parametrize("num_gpus", [1])
def test_low_mem(num_gpus):
    """Run Open-Sora generation with ``cpu_offload=True`` and ``tiling_size=1``.

    The test passes when generation and saving complete without raising.
    """
    prompt = "Sunset over the sea."
    low_mem_config = OpenSoraConfig(num_gpus=num_gpus, cpu_offload=True, tiling_size=1)
    engine = VideoSysEngine(low_mem_config)

    # Only the first entry of the returned ``video`` sequence is saved.
    first_video = engine.generate(prompt).video[0]
    engine.save_video(first_video, f"./test_outputs/{prompt}_open_sora_low_mem_{num_gpus}.mp4")

tests/pipelines/open_sora_plan/test_open_sora_plan.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,13 @@ def test_pab(num_gpus):
2121
prompt = "Sunset over the sea."
2222
video = engine.generate(prompt).video[0]
2323
engine.save_video(video, f"./test_outputs/{prompt}_open_sora_plan_pab_{num_gpus}.mp4")
24+
25+
26+
@pytest.mark.parametrize("num_gpus", [1])
def test_low_mem(num_gpus):
    """Run Open-Sora-Plan generation with ``cpu_offload=True`` and ``enable_tiling=True``.

    The test passes when generation and saving complete without raising.
    """
    prompt = "Sunset over the sea."
    low_mem_config = OpenSoraPlanConfig(num_gpus=num_gpus, cpu_offload=True, enable_tiling=True)
    engine = VideoSysEngine(low_mem_config)

    # Only the first entry of the returned ``video`` sequence is saved.
    first_video = engine.generate(prompt).video[0]
    engine.save_video(first_video, f"./test_outputs/{prompt}_open_sora_plan_low_mem_{num_gpus}.mp4")

videosys/core/engine.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, Optional
44

55
import torch
6+
import torch.distributed as dist
67

78
import videosys
89

@@ -22,9 +23,6 @@ def __init__(self, config):
2223
def _init_worker(self, pipeline_cls):
2324
world_size = self.config.num_gpus
2425

25-
if "CUDA_VISIBLE_DEVICES" not in os.environ:
26-
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(world_size))
27-
2826
# Disable torch async compiling which won't work with daemonic processes
2927
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
3028

@@ -124,7 +122,7 @@ def save_video(self, video, output_path):
124122
def shutdown(self):
125123
if (worker_monitor := getattr(self, "worker_monitor", None)) is not None:
126124
worker_monitor.close()
127-
torch.distributed.destroy_process_group()
125+
dist.destroy_process_group()
128126

129127
def __del__(self):
130128
self.shutdown()

videosys/models/autoencoders/autoencoder_kl_open_sora.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,9 @@ def encode(self, x):
670670
return (z - self.shift) / self.scale
671671

672672
def decode(self, z, num_frames=None):
673+
device = z.device
674+
self.scale = self.scale.to(device)
675+
self.shift = self.shift.to(device)
673676
if not self.cal_loss:
674677
z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype)
675678

0 commit comments

Comments
 (0)