Commit e48a642

[gradio] update gradio code and doc (#220)
1 parent: f99ad20

3 files changed: +52 lines added, -101 lines removed


README.md

Lines changed: 3 additions & 1 deletion

```diff
@@ -4,7 +4,7 @@
 <h3 align="center">
 An easy and efficient system for video generation
 </h3>
-<p align="center">| <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys?tab=readme-ov-file#installation">Quick Start</a> | <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys?tab=readme-ov-file#usage">Supported Models</a> | <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys?tab=readme-ov-file#acceleration-techniques">Accelerations</a> | <a href="https://discord.gg/WhPmYm9FeG">Discord</a> | <a href="https://oahzxl.notion.site/VideoSys-News-42391db7e0a44f96a1f0c341450ae472?pvs=4">Media</a> |
+<p align="center">| <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys?tab=readme-ov-file#installation">Quick Start</a> | <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys?tab=readme-ov-file#usage">Supported Models</a> | <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys?tab=readme-ov-file#acceleration-techniques">Accelerations</a> | <a href="https://discord.gg/WhPmYm9FeG">Discord</a> | <a href="https://oahzxl.notion.site/VideoSys-News-42391db7e0a44f96a1f0c341450ae472?pvs=4">Media</a> | <a href="https://huggingface.co/VideoSys">HuggingFace Space</a> |
 </p>

 ### Latest News 🔥

@@ -106,6 +106,8 @@ VideoSys supports many diffusion models with our various acceleration techniques
 </tr>
 </table>

+You can also find easy demos with HuggingFace Space <a href="https://huggingface.co/VideoSys">[link]</a> and Gradio <a href="./gradio">[link]</a>.
+
 ## Acceleration Techniques

 ### Pyramid Attention Broadcast (PAB) [[paper](https://arxiv.org/abs/2408.12588)][[blog](https://arxiv.org/abs/2403.10266)][[doc](./docs/pab.md)]
```

gradio/README.md

Lines changed: 4 additions & 0 deletions

```diff
@@ -0,0 +1,4 @@
+## Gradio Demo
+Here are local Gradio demos for an easy UI and visualization. You can also find online demos on <a href="https://huggingface.co/VideoSys">HuggingFace Space</a>.
+
+It's very easy to run the scripts: `python xxx.py`
```
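For orientation, here is a minimal sketch of how such a demo script is typically wired with Gradio. The `run` function, labels, and output path below are placeholders for illustration rather than the repository's code; only the queue settings mirror the CogVideoX demo in this commit.

```python
import gradio as gr


def run(prompt: str) -> str:
    # Placeholder: the real demos build a VideoSys engine and return the path of the generated video.
    return "./.tmp_outputs/sample.mp4"


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    button = gr.Button("Generate")
    video = gr.Video(label="Result")
    button.click(run, inputs=[prompt], outputs=[video])

if __name__ == "__main__":
    demo.queue(max_size=10, default_concurrency_limit=1)  # same queue settings as gradio/cogvideox.py
    demo.launch()
```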

gradio/cogvideox.py

Lines changed: 45 additions & 100 deletions

```diff
@@ -3,25 +3,17 @@
 os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), ".tmp_outputs")
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

-import logging
 import uuid

-import GPUtil
-import psutil
-import torch
+import spaces

 import gradio as gr
 from videosys import CogVideoXConfig, CogVideoXPABConfig, VideoSysEngine

-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)

-dtype = torch.float16
-
-
-def load_model(enable_video_sys=False, pab_threshold=[100, 850], pab_range=2):
+def load_model(model_name, enable_video_sys=False, pab_threshold=[100, 850], pab_range=2):
     pab_config = CogVideoXPABConfig(spatial_threshold=pab_threshold, spatial_range=pab_range)
-    config = CogVideoXConfig(num_gpus=1, enable_pab=enable_video_sys, pab_config=pab_config)
+    config = CogVideoXConfig(model_name, enable_pab=enable_video_sys, pab_config=pab_config)
     engine = VideoSysEngine(config)
     return engine
```
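The reworked `load_model` now takes the model name explicitly instead of a hard-coded `num_gpus=1` config. A minimal standalone sketch of the same call path, assuming VideoSys is installed and exposes the classes imported above (the model id is the one offered in the demo UI below):

```python
from videosys import CogVideoXConfig, CogVideoXPABConfig, VideoSysEngine

# Mirrors the updated load_model helper: PAB settings are bundled into the model config.
pab_config = CogVideoXPABConfig(spatial_threshold=[100, 850], spatial_range=2)
config = CogVideoXConfig("THUDM/CogVideoX-2b", enable_pab=True, pab_config=pab_config)
engine = VideoSysEngine(config)
```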

```diff
@@ -36,33 +28,9 @@ def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0):
     return output_path


-def get_server_status():
-    cpu_percent = psutil.cpu_percent()
-    memory = psutil.virtual_memory()
-    disk = psutil.disk_usage("/")
-    gpus = GPUtil.getGPUs()
-    gpu_info = []
-    for gpu in gpus:
-        gpu_info.append(
-            {
-                "id": gpu.id,
-                "name": gpu.name,
-                "load": f"{gpu.load*100:.1f}%",
-                "memory_used": f"{gpu.memoryUsed}MB",
-                "memory_total": f"{gpu.memoryTotal}MB",
-            }
-        )
-
-    return {"cpu": f"{cpu_percent}%", "memory": f"{memory.percent}%", "disk": f"{disk.percent}%", "gpu": gpu_info}
-
-
-def generate_vanilla(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
-    engine = load_model()
-    video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
-    return video_path
-
-
+@spaces.GPU(duration=200)
 def generate_vs(
+    model_name,
     prompt,
     num_inference_steps,
     guidance_scale,

@@ -73,38 +41,11 @@
 ):
     threshold = [int(threshold_end), int(threshold_start)]
     gap = int(gap)
-    engine = load_model(enable_video_sys=True, pab_threshold=threshold, pab_range=gap)
+    engine = load_model(model_name, enable_video_sys=True, pab_threshold=threshold, pab_range=gap)
     video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
     return video_path


-def get_server_status():
-    cpu_percent = psutil.cpu_percent()
-    memory = psutil.virtual_memory()
-    disk = psutil.disk_usage("/")
-    try:
-        gpus = GPUtil.getGPUs()
-        if gpus:
-            gpu = gpus[0]
-            gpu_memory = f"{gpu.memoryUsed}/{gpu.memoryTotal}MB ({gpu.memoryUtil*100:.1f}%)"
-        else:
-            gpu_memory = "No GPU found"
-    except:
-        gpu_memory = "GPU information unavailable"
-
-    return {
-        "cpu": f"{cpu_percent}%",
-        "memory": f"{memory.percent}%",
-        "disk": f"{disk.percent}%",
-        "gpu_memory": gpu_memory,
-    }
-
-
-def update_server_status():
-    status = get_server_status()
-    return (status["cpu"], status["memory"], status["disk"], status["gpu_memory"])
-
-
 css = """
 body {
     font-family: Arial, sans-serif;
```
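The server-status helpers built on psutil and GPUtil are dropped in favor of the `spaces` integration shown above. For reference, the decorator pattern looks like the sketch below; the reading of `duration` is an assumption based on the HuggingFace ZeroGPU `spaces` package, and `heavy_fn` is a placeholder (in this commit the decorated function is `generate_vs`).

```python
import spaces


@spaces.GPU(duration=200)  # assumption: reserve a ZeroGPU device for up to ~200 s per call
def heavy_fn(prompt):
    # Placeholder body; the real entry point builds a VideoSys engine and renders a video.
    ...
```
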
```diff
@@ -206,60 +147,64 @@ def update_server_status():

     with gr.Row():
         with gr.Column():
-            prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="Sunset over the sea.", lines=4)
+            prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="Sunset over the sea.", lines=2)

             with gr.Column():
                 gr.Markdown("**Generation Parameters**<br>")
                 with gr.Row():
-                    num_inference_steps = gr.Number(label="Inference Steps", value=50)
-                    guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
+                    model_name = gr.Radio(["THUDM/CogVideoX-2b"], label="Model Type", value="THUDM/CogVideoX-2b")
                 with gr.Row():
-                    pab_range = gr.Number(
-                        label="PAB Broadcast Range", value=2, precision=0, info="Broadcast timesteps range."
+                    num_inference_steps = gr.Slider(label="Inference Steps", maximum=50, value=50)
+                    guidance_scale = gr.Slider(label="Guidance Scale", value=6.0, maximum=15.0)
+                gr.Markdown("**Pyramid Attention Broadcast Parameters**<br>")
+                with gr.Row():
+                    pab_range = gr.Slider(
+                        label="Broadcast Range",
+                        value=2,
+                        step=1,
+                        minimum=1,
+                        maximum=4,
+                        info="Attention broadcast range.",
+                    )
+                    pab_threshold_start = gr.Slider(
+                        label="Start Timestep",
+                        minimum=500,
+                        maximum=1000,
+                        value=850,
+                        step=1,
+                        info="Broadcast start timestep (1000 is the first).",
+                    )
+                    pab_threshold_end = gr.Slider(
+                        label="End Timestep",
+                        minimum=0,
+                        maximum=500,
+                        step=1,
+                        value=100,
+                        info="Broadcast end timestep (0 is the last).",
                     )
-                    pab_threshold_start = gr.Number(label="PAB Start Timestep", value=850, info="Start from step 1000.")
-                    pab_threshold_end = gr.Number(label="PAB End Timestep", value=100, info="End at step 0.")
                 with gr.Row():
-                    generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys (Faster)")
-                    generate_button = gr.Button("🎬 Generate Video (Original)")
-                with gr.Column(elem_classes="server-status"):
-                    gr.Markdown("#### Server Status")
-
-                    with gr.Row():
-                        cpu_status = gr.Textbox(label="CPU", scale=1)
-                        memory_status = gr.Textbox(label="Memory", scale=1)
-
-                    with gr.Row():
-                        disk_status = gr.Textbox(label="Disk", scale=1)
-                        gpu_status = gr.Textbox(label="GPU Memory", scale=1)
-
-                    with gr.Row():
-                        refresh_button = gr.Button("Refresh")
+                    generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys")

         with gr.Column():
             with gr.Row():
                 video_output_vs = gr.Video(label="CogVideoX with VideoSys", width=720, height=480)
-            with gr.Row():
-                video_output = gr.Video(label="CogVideoX", width=720, height=480)
-
-    generate_button.click(
-        generate_vanilla,
-        inputs=[prompt, num_inference_steps, guidance_scale],
-        outputs=[video_output],
-        concurrency_id="gen",
-        concurrency_limit=1,
-    )

     generate_button_vs.click(
         generate_vs,
-        inputs=[prompt, num_inference_steps, guidance_scale, pab_threshold_start, pab_threshold_end, pab_range],
+        inputs=[
+            model_name,
+            prompt,
+            num_inference_steps,
+            guidance_scale,
+            pab_threshold_start,
+            pab_threshold_end,
+            pab_range,
+        ],
         outputs=[video_output_vs],
         concurrency_id="gen",
         concurrency_limit=1,
     )

-    refresh_button.click(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status])
-    demo.load(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status], every=1)

 if __name__ == "__main__":
     demo.queue(max_size=10, default_concurrency_limit=1)
```
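Taken together, pressing "⚡️ Generate Video with VideoSys" with the default UI values is roughly equivalent to the direct calls below into the helpers defined in this file. This is an illustrative sketch, not part of the commit; note that `generate_vs` passes the PAB threshold as `[end, start]`, which with the defaults reproduces `load_model`'s own default of `[100, 850]`.

```python
# Default UI values: model "THUDM/CogVideoX-2b", 50 steps, guidance 6.0,
# PAB window from timestep 850 down to 100, broadcast range 2.
threshold_start, threshold_end, pab_range = 850, 100, 2

engine = load_model(
    "THUDM/CogVideoX-2b",
    enable_video_sys=True,
    pab_threshold=[int(threshold_end), int(threshold_start)],  # [100, 850]
    pab_range=int(pab_range),
)
video_path = generate(engine, "Sunset over the sea.", num_inference_steps=50, guidance_scale=6.0)
```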
