33import typing
44from functools import lru_cache
55from tempfile import TemporaryDirectory
6- from urllib .request import urlretrieve
76
87import PIL .Image
98from basicsr .archs .rrdbnet_arch import RRDBNet
1413
1514import gooey_gpu
1615from celeryconfig import app , setup_queues
17- from ffmpeg_util import (
18- ffmpeg_get_writer_proc ,
19- ffmpeg_read_input_frames ,
20- ffprobe_video ,
21- VideoMetadata ,
22- InputOutputVideoMetadata ,
23- )
2416
2517MAX_RES = 1920 * 1080
2618
@@ -40,7 +32,7 @@ class EsrganInputs(BaseModel):
4032@gooey_gpu .endpoint
4133def realesrgan (
4234 pipeline : EsrganPipeline , inputs : EsrganInputs
43- ) -> InputOutputVideoMetadata :
35+ ) -> gooey_gpu . InputOutputVideoMetadata :
4436 esrganer = load_esrgan_model (pipeline .model_id )
4537
4638 def enhance (frame , outscale_factor ):
@@ -71,7 +63,9 @@ class GfpganInputs(BaseModel):
7163
7264@app .task (name = "gfpgan" )
7365@gooey_gpu .endpoint
74- def gfpgan (pipeline : GfpganPipeline , inputs : GfpganInputs ) -> InputOutputVideoMetadata :
66+ def gfpgan (
67+ pipeline : GfpganPipeline , inputs : GfpganInputs
68+ ) -> gooey_gpu .InputOutputVideoMetadata :
7569 gfpganer = load_gfpgan_model (pipeline .model_id )
7670 if pipeline .bg_model_id :
7771 gfpganer .bg_upsampler = load_esrgan_model (pipeline .bg_model_id )
@@ -102,24 +96,22 @@ def run_enhancer(
10296 scale : float ,
10397 upload_url : str ,
10498 enhance : typing .Callable ,
105- ) -> InputOutputVideoMetadata :
106- input_file = image or video
107- assert input_file , "Please provide an image or video input"
99+ ) -> gooey_gpu . InputOutputVideoMetadata :
100+ input_url = image or video
101+ assert input_url , "Please provide an image or video input"
108102
109103 with TemporaryDirectory () as save_dir :
110- input_path , _ = urlretrieve (
111- input_file ,
112- os .path .join (save_dir , "input" + os .path .splitext (input_file )[1 ]),
113- )
104+ input_path = os .path .join (save_dir , "input" + os .path .splitext (input_url )[1 ])
105+ gooey_gpu .download_file_to_path (url = input_url , path = input_path )
114106 output_path = os .path .join (save_dir , "out.mp4" )
115107
116- response = InputOutputVideoMetadata (
117- input = ffprobe_video (input_path ), output = VideoMetadata ()
108+ response = gooey_gpu . InputOutputVideoMetadata (
109+ input = gooey_gpu . ffprobe_video (input_path ), output = gooey_gpu . VideoMetadata ()
118110 )
119111 # ensure max input/output is 1080p
120112 input_pixels = response .input .width * response .input .height
121113 if input_pixels > MAX_RES :
122- raise ValueError (
114+ raise gooey_gpu . UserError (
123115 "Input video resolution exceeds 1920x1080. Please downscale to 1080p."
124116 )
125117 max_scale = math .sqrt (MAX_RES / input_pixels )
@@ -128,7 +120,7 @@ def run_enhancer(
128120
129121 ffproc = None
130122 for frame in tqdm (
131- ffmpeg_read_input_frames (
123+ gooey_gpu . ffmpeg_read_input_frames (
132124 width = response .input .width ,
133125 height = response .input .height ,
134126 input_path = input_path ,
@@ -152,7 +144,7 @@ def run_enhancer(
152144 response .output .width = restored_img .shape [1 ]
153145 response .output .height = restored_img .shape [0 ]
154146 response .output .fps = response .input .fps or 24
155- ffproc = ffmpeg_get_writer_proc (
147+ ffproc = gooey_gpu . ffmpeg_get_writer_proc (
156148 width = response .output .width ,
157149 height = response .output .height ,
158150 fps = response .output .fps ,
@@ -214,7 +206,7 @@ def load_gfpgan_model(model_id: str) -> "GFPGANer":
214206
215207 print (f"loading { model_id } via { url } ..." )
216208 model_path = os .path .join (gfpgan_checkpoint_dir , os .path .basename (url ))
217- gooey_gpu .download_file_cached (url = url , path = model_path )
209+ gooey_gpu .download_file_to_path (url = url , path = model_path , cached = True )
218210
219211 return GFPGANer (
220212 model_path = model_path ,
@@ -282,7 +274,7 @@ def load_esrgan_model(model_id: str) -> "RealESRGANer":
282274 for url in file_url :
283275 print (f"loading { model_id } via { url } ..." )
284276 model_path = os .path .join (gooey_gpu .CHECKPOINTS_DIR , os .path .basename (url ))
285- gooey_gpu .download_file_cached (url = url , path = model_path )
277+ gooey_gpu .download_file_to_path (url = url , path = model_path , cached = True )
286278 assert model_path , f"Model { model_id } not found"
287279
288280 return RealESRGANer (
0 commit comments