 os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), ".tmp_outputs")
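+# expandable_segments lets the CUDA caching allocator grow blocks in place,
+# which reduces fragmentation-related OOMs during long video generations.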
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

-import logging
 import uuid

-import GPUtil
-import psutil
-import torch
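+# `spaces` provides the Hugging Face ZeroGPU decorator (`@spaces.GPU`) used below.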
+import spaces

 import gradio as gr
 from videosys import CogVideoXConfig, CogVideoXPABConfig, VideoSysEngine

-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)

-dtype = torch.float16
-
-
-def load_model(enable_video_sys=False, pab_threshold=[100, 850], pab_range=2):
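+# Build a VideoSys engine for the chosen checkpoint. With enable_video_sys set,
+# Pyramid Attention Broadcast (PAB) reuses attention outputs across neighboring
+# diffusion timesteps inside the threshold window to cut sampling cost.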
+def load_model(model_name, enable_video_sys=False, pab_threshold=[100, 850], pab_range=2):
     pab_config = CogVideoXPABConfig(spatial_threshold=pab_threshold, spatial_range=pab_range)
-    config = CogVideoXConfig(num_gpus=1, enable_pab=enable_video_sys, pab_config=pab_config)
+    config = CogVideoXConfig(model_name, enable_pab=enable_video_sys, pab_config=pab_config)
     engine = VideoSysEngine(config)
     return engine

@@ -36,33 +28,9 @@ def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0):
     return output_path


-def get_server_status():
-    cpu_percent = psutil.cpu_percent()
-    memory = psutil.virtual_memory()
-    disk = psutil.disk_usage("/")
-    gpus = GPUtil.getGPUs()
-    gpu_info = []
-    for gpu in gpus:
-        gpu_info.append(
-            {
-                "id": gpu.id,
-                "name": gpu.name,
-                "load": f"{gpu.load * 100:.1f}%",
-                "memory_used": f"{gpu.memoryUsed}MB",
-                "memory_total": f"{gpu.memoryTotal}MB",
-            }
-        )
-
-    return {"cpu": f"{cpu_percent}%", "memory": f"{memory.percent}%", "disk": f"{disk.percent}%", "gpu": gpu_info}
-
-
-def generate_vanilla(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
-    engine = load_model()
-    video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
-    return video_path
-
-
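+# ZeroGPU: a GPU is attached only for the duration of this call (at most 200 s).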
+@spaces.GPU(duration=200)
 def generate_vs(
+    model_name,
     prompt,
     num_inference_steps,
     guidance_scale,
@@ -73,38 +41,11 @@ def generate_vs(
 ):
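+    # Diffusion timesteps count down from 1000 to 0, so PAB takes [end, start].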
     threshold = [int(threshold_end), int(threshold_start)]
     gap = int(gap)
-    engine = load_model(enable_video_sys=True, pab_threshold=threshold, pab_range=gap)
+    engine = load_model(model_name, enable_video_sys=True, pab_threshold=threshold, pab_range=gap)
     video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
     return video_path


-def get_server_status():
-    cpu_percent = psutil.cpu_percent()
-    memory = psutil.virtual_memory()
-    disk = psutil.disk_usage("/")
-    try:
-        gpus = GPUtil.getGPUs()
-        if gpus:
-            gpu = gpus[0]
-            gpu_memory = f"{gpu.memoryUsed}/{gpu.memoryTotal}MB ({gpu.memoryUtil * 100:.1f}%)"
-        else:
-            gpu_memory = "No GPU found"
-    except:
-        gpu_memory = "GPU information unavailable"
-
-    return {
-        "cpu": f"{cpu_percent}%",
-        "memory": f"{memory.percent}%",
-        "disk": f"{disk.percent}%",
-        "gpu_memory": gpu_memory,
-    }
-
-
-def update_server_status():
-    status = get_server_status()
-    return (status["cpu"], status["memory"], status["disk"], status["gpu_memory"])
-
-
 css = """
 body {
     font-family: Arial, sans-serif;
@@ -206,60 +147,64 @@ def update_server_status():

     with gr.Row():
         with gr.Column():
-            prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="Sunset over the sea.", lines=4)
+            prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="Sunset over the sea.", lines=2)

         with gr.Column():
             gr.Markdown("**Generation Parameters**<br>")
             with gr.Row():
-                num_inference_steps = gr.Number(label="Inference Steps", value=50)
-                guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
+                model_name = gr.Radio(["THUDM/CogVideoX-2b"], label="Model Type", value="THUDM/CogVideoX-2b")
             with gr.Row():
-                pab_range = gr.Number(
-                    label="PAB Broadcast Range", value=2, precision=0, info="Broadcast timesteps range."
+                num_inference_steps = gr.Slider(label="Inference Steps", maximum=50, value=50)
+                guidance_scale = gr.Slider(label="Guidance Scale", value=6.0, maximum=15.0)
+            gr.Markdown("**Pyramid Attention Broadcast Parameters**<br>")
+            with gr.Row():
+                pab_range = gr.Slider(
+                    label="Broadcast Range",
+                    value=2,
+                    step=1,
+                    minimum=1,
+                    maximum=4,
+                    info="Attention broadcast range.",
+                )
+                pab_threshold_start = gr.Slider(
+                    label="Start Timestep",
+                    minimum=500,
+                    maximum=1000,
+                    value=850,
+                    step=1,
+                    info="Broadcast start timestep (1000 is the first).",
+                )
+                pab_threshold_end = gr.Slider(
+                    label="End Timestep",
+                    minimum=0,
+                    maximum=500,
+                    step=1,
+                    value=100,
+                    info="Broadcast end timestep (0 is the last).",
                 )
-                pab_threshold_start = gr.Number(label="PAB Start Timestep", value=850, info="Start from step 1000.")
-                pab_threshold_end = gr.Number(label="PAB End Timestep", value=100, info="End at step 0.")
             with gr.Row():
-                generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys (Faster)")
-                generate_button = gr.Button("🎬 Generate Video (Original)")
-        with gr.Column(elem_classes="server-status"):
-            gr.Markdown("#### Server Status")
-
-            with gr.Row():
-                cpu_status = gr.Textbox(label="CPU", scale=1)
-                memory_status = gr.Textbox(label="Memory", scale=1)
-
-            with gr.Row():
-                disk_status = gr.Textbox(label="Disk", scale=1)
-                gpu_status = gr.Textbox(label="GPU Memory", scale=1)
-
-            with gr.Row():
-                refresh_button = gr.Button("Refresh")
+                generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys")

         with gr.Column():
             with gr.Row():
                 video_output_vs = gr.Video(label="CogVideoX with VideoSys", width=720, height=480)
-            with gr.Row():
-                video_output = gr.Video(label="CogVideoX", width=720, height=480)
-
-    generate_button.click(
-        generate_vanilla,
-        inputs=[prompt, num_inference_steps, guidance_scale],
-        outputs=[video_output],
-        concurrency_id="gen",
-        concurrency_limit=1,
-    )

     generate_button_vs.click(
         generate_vs,
-        inputs=[prompt, num_inference_steps, guidance_scale, pab_threshold_start, pab_threshold_end, pab_range],
+        inputs=[
+            model_name,
+            prompt,
+            num_inference_steps,
+            guidance_scale,
+            pab_threshold_start,
+            pab_threshold_end,
+            pab_range,
+        ],
         outputs=[video_output_vs],
         concurrency_id="gen",
         concurrency_limit=1,
     )

-    refresh_button.click(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status])
-    demo.load(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status], every=1)

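+# Serve one generation at a time; queued requests wait for the shared "gen" slot.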
 if __name__ == "__main__":
     demo.queue(max_size=10, default_concurrency_limit=1)