Skip to content

Commit 0e3463f

Browse files
committed
Refactor error handling and subprocess calls, extract common utilities
- Replace `ValueError` with `UserError` for better user feedback and sentry sanity - Copy over centralized HTTP & ffmpeg error handling in `exceptions.py` from gooey-server - Simplify file download logic with `download_file_to_path()` - Avoid too broad exception clause to handle bad face reco in sadtalker - Handle bad face reco in eyeblink and ref_post too - Remove stale wav2lip-src folder
1 parent 8c86f97 commit 0e3463f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+306
-61884
lines changed

chart/model-values.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ deployments:
304304
thenlper/gte-base
305305
306306
- name: "retro-sadtalker"
307-
image: *retroImg
307+
image: "crgooeyprodwestus1.azurecr.io/gooey-gpu-retro:9"
308308
autoscaling:
309309
queueLength: 2
310310
minReplicaCount: 3

common/diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def _safety_checker(clip_input, images):
179179
if not disabled:
180180
images, has_nsfw_concepts = original(images=images, clip_input=clip_input)
181181
if any(has_nsfw_concepts):
182-
raise ValueError(
182+
raise gooey_gpu.UserError(
183183
"Potential NSFW content was detected in one or more images. "
184184
"Try again with a different Prompt and/or Regenerate."
185185
)

deforum_sd/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def deforum(pipeline: PipelineInfo, inputs: deforum_script.DeforumAnimArgs):
4040
headers={"Content-Type": "video/mp4"},
4141
data=vid_bytes,
4242
)
43-
r.raise_for_status()
43+
gooey_gpu.raise_for_status(r)
4444
return
4545

4646

deforum_sd/deforum_script.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,7 @@ def create_video(args: DeforumArgs, anim_args: DeforumAnimArgs):
379379
max_frames = str(anim_args.max_frames)
380380

381381
# make video
382-
cmd = [
383-
"ffmpeg",
384-
"-y",
382+
gooey_gpu.ffmpeg(
385383
"-vcodec",
386384
bitdepth_extension,
387385
"-r",
@@ -405,9 +403,7 @@ def create_video(args: DeforumArgs, anim_args: DeforumAnimArgs):
405403
"-pattern_type",
406404
"sequence",
407405
mp4_path,
408-
]
409-
print(f"---> {' '.join(cmd)}")
410-
subprocess.check_call(cmd)
406+
)
411407
# process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
412408
# stdout, stderr = process.communicate()
413409
# if process.returncode != 0:

exceptions.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import typing
2+
3+
import requests
4+
5+
6+
class UserError(Exception):
7+
def __init__(
8+
self,
9+
message: str,
10+
sentry_level: str = "info",
11+
status_code: typing.Optional[int] = None,
12+
):
13+
self.message = message
14+
self.sentry_level = sentry_level
15+
self.status_code = status_code
16+
super().__init__(
17+
dict(message=message, sentry_level=sentry_level, status_code=status_code)
18+
)
19+
20+
def __str__(self):
21+
return self.message
22+
23+
24+
def raise_for_status(resp: requests.Response, is_user_url: bool = False):
25+
"""Raises :class:`HTTPError`, if one occurred."""
26+
27+
http_error_msg = ""
28+
if isinstance(resp.reason, bytes):
29+
# We attempt to decode utf-8 first because some servers
30+
# choose to localize their reason strings. If the string
31+
# isn't utf-8, we fall back to iso-8859-1 for all other
32+
# encodings. (See PR #3538)
33+
try:
34+
reason = resp.reason.decode("utf-8")
35+
except UnicodeDecodeError:
36+
reason = resp.reason.decode("iso-8859-1")
37+
else:
38+
reason = resp.reason
39+
40+
if 400 <= resp.status_code < 500:
41+
http_error_msg = f"{resp.status_code} Client Error: {reason} | URL: {resp.url} | Response: {_response_preview(resp)!r}"
42+
43+
elif 500 <= resp.status_code < 600:
44+
http_error_msg = f"{resp.status_code} Server Error: {reason} | URL: {resp.url} | Response: {_response_preview(resp)!r}"
45+
46+
if http_error_msg:
47+
exc = requests.HTTPError(http_error_msg, response=resp)
48+
if is_user_url:
49+
raise UserError(
50+
f"[{resp.status_code}] You have provided an invalid URL: {resp.url} "
51+
"Please make sure the URL is correct and accessible. ",
52+
) from exc
53+
else:
54+
raise exc
55+
56+
57+
def _response_preview(resp: requests.Response) -> bytes:
58+
return truncate_filename(resp.content, 500, sep=b"...")
59+
60+
61+
def truncate_filename(
62+
text: typing.AnyStr, maxlen: int = 100, sep: typing.AnyStr = "..."
63+
) -> typing.AnyStr:
64+
if len(text) <= maxlen:
65+
return text
66+
assert len(sep) <= maxlen
67+
mid = (maxlen - len(sep)) // 2
68+
return text[:mid] + sep + text[-mid:]

ffmpeg_util.py

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import numpy as np
77
from pydantic import BaseModel
88

9+
from exceptions import UserError
10+
911

1012
class VideoMetadata(BaseModel):
1113
width: int = 0
@@ -23,40 +25,48 @@ class InputOutputVideoMetadata(BaseModel):
2325

2426
class AudioMetadata(BaseModel):
2527
duration_sec: float = 0
28+
codec_name: typing.Optional[str] = None
2629

2730

2831
def ffprobe_audio(input_path: str) -> AudioMetadata:
29-
cmd_args = [
32+
text = call_cmd(
3033
"ffprobe",
31-
"-v",
32-
"error",
33-
"-show_entries",
34-
"format=duration",
35-
"-of",
36-
"default=noprint_wrappers=1:nokey=1",
37-
input_path,
38-
]
39-
print("\t$ " + " ".join(cmd_args))
34+
"-v", "quiet",
35+
"-print_format", "json",
36+
"-show_streams", input_path,
37+
"-select_streams", "a:0",
38+
) # fmt:skip
39+
data = json.loads(text)
40+
41+
try:
42+
stream = data["streams"][0]
43+
except IndexError:
44+
raise UserError(
45+
"Input has no audio streams. Make sure the you have uploaded an appropriate audio/video file."
46+
)
47+
4048
return AudioMetadata(
41-
duration_sec=float(subprocess.check_output(cmd_args, encoding="utf-8"))
49+
duration_sec=float(stream.get("duration") or 0),
50+
codec_name=stream.get("codec_name"),
4251
)
4352

4453

4554
def ffprobe_video(input_path: str) -> VideoMetadata:
46-
cmd_args = [
55+
text = call_cmd(
4756
"ffprobe",
4857
"-v", "quiet",
4958
"-print_format", "json",
5059
"-show_streams", input_path,
5160
"-select_streams", "v:0",
52-
] # fmt:skip
53-
print("\t$ " + " ".join(cmd_args))
54-
data = json.loads(subprocess.check_output(cmd_args, text=True))
61+
) # fmt:skip
62+
data = json.loads(text)
5563

5664
try:
5765
stream = data["streams"][0]
5866
except IndexError:
59-
raise ValueError("input has no video streams")
67+
raise UserError(
68+
"Input has no video streams. Make sure the video you have uploaded is not corrupted."
69+
)
6070

6171
try:
6272
fps = float(Fraction(stream["avg_frame_rate"]))
@@ -120,3 +130,30 @@ def ffmpeg_get_writer_proc(
120130
] # fmt:skip
121131
print("\t$ " + " ".join(cmd_args))
122132
return subprocess.Popen(cmd_args, stdin=subprocess.PIPE)
133+
134+
135+
FFMPEG_ERR_MSG = (
136+
"Unsupported File Format\n\n"
137+
"We encountered an issue processing your file as it appears to be in a format not supported by our system or may be corrupted. "
138+
"You can find a list of supported formats at [FFmpeg Formats](https://ffmpeg.org/general.html#File-Formats)."
139+
)
140+
141+
142+
def ffmpeg(*args) -> str:
143+
return call_cmd("ffmpeg", "-hide_banner", "-y", *args, err_msg=FFMPEG_ERR_MSG)
144+
145+
146+
def call_cmd(
147+
*args, err_msg: str = "", ok_returncodes: typing.Iterable[int] = ()
148+
) -> str:
149+
print("\t$ " + " ".join(map(str, args)))
150+
try:
151+
return subprocess.check_output(args, stderr=subprocess.STDOUT, text=True)
152+
except subprocess.CalledProcessError as e:
153+
if e.returncode in ok_returncodes:
154+
return e.output
155+
err_msg = err_msg or f"{str(args[0]).capitalize()} Error"
156+
try:
157+
raise subprocess.SubprocessError(e.output) from e
158+
except subprocess.SubprocessError as e:
159+
raise UserError(err_msg) from e

gooey_gpu.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import mimetypes
77
import os
88
import threading
9-
import typing
109
from concurrent.futures import ThreadPoolExecutor
1110
from functools import wraps
1211

@@ -16,7 +15,9 @@
1615
import sentry_sdk
1716
import torch
1817
import transformers
19-
from pydantic import BaseModel
18+
19+
from exceptions import raise_for_status
20+
from ffmpeg_util import *
2021

2122
# from accelerate import cpu_offload_with_hook
2223

@@ -191,7 +192,7 @@ def upload_image(im_pil: PIL.Image.Image, url: str):
191192
headers={"Content-Type": "image/png"},
192193
data=im_bytes,
193194
)
194-
r.raise_for_status()
195+
raise_for_status(r)
195196

196197

197198
def apply_parallel(fn, *iterables):
@@ -220,22 +221,22 @@ def upload_audio(audio, url: str, rate: int = 16_000):
220221

221222
def upload_audio_from_bytes(audio: bytes, url: str):
222223
r = requests.put(url, headers={"Content-Type": "audio/wav"}, data=audio)
223-
r.raise_for_status()
224+
raise_for_status(r)
224225

225226

226227
def upload_video_from_bytes(video, url: str):
227228
r = requests.put(url, headers={"Content-Type": "video/mp4"}, data=video)
228-
r.raise_for_status()
229+
raise_for_status(r)
229230

230231

231232
# Add some missing mimetypes
232233
mimetypes.add_type("audio/wav", ".wav")
233234

234235

235-
def download_file_cached(*, url: str, path: str):
236-
if os.path.exists(path):
236+
def download_file_to_path(*, url: str, path: str, cached: bool = False):
237+
if cached and os.path.exists(path):
237238
return
238239
r = requests.get(url)
239-
r.raise_for_status()
240+
raise_for_status(r, is_user_url=not cached)
240241
with open(path, "wb") as f:
241242
f.write(r.content)

retro/gfpgan.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import typing
44
from functools import lru_cache
55
from tempfile import TemporaryDirectory
6-
from urllib.request import urlretrieve
76

87
import PIL.Image
98
from basicsr.archs.rrdbnet_arch import RRDBNet
@@ -14,13 +13,6 @@
1413

1514
import gooey_gpu
1615
from celeryconfig import app, setup_queues
17-
from ffmpeg_util import (
18-
ffmpeg_get_writer_proc,
19-
ffmpeg_read_input_frames,
20-
ffprobe_video,
21-
VideoMetadata,
22-
InputOutputVideoMetadata,
23-
)
2416

2517
MAX_RES = 1920 * 1080
2618

@@ -40,7 +32,7 @@ class EsrganInputs(BaseModel):
4032
@gooey_gpu.endpoint
4133
def realesrgan(
4234
pipeline: EsrganPipeline, inputs: EsrganInputs
43-
) -> InputOutputVideoMetadata:
35+
) -> gooey_gpu.InputOutputVideoMetadata:
4436
esrganer = load_esrgan_model(pipeline.model_id)
4537

4638
def enhance(frame, outscale_factor):
@@ -71,7 +63,9 @@ class GfpganInputs(BaseModel):
7163

7264
@app.task(name="gfpgan")
7365
@gooey_gpu.endpoint
74-
def gfpgan(pipeline: GfpganPipeline, inputs: GfpganInputs) -> InputOutputVideoMetadata:
66+
def gfpgan(
67+
pipeline: GfpganPipeline, inputs: GfpganInputs
68+
) -> gooey_gpu.InputOutputVideoMetadata:
7569
gfpganer = load_gfpgan_model(pipeline.model_id)
7670
if pipeline.bg_model_id:
7771
gfpganer.bg_upsampler = load_esrgan_model(pipeline.bg_model_id)
@@ -102,24 +96,22 @@ def run_enhancer(
10296
scale: float,
10397
upload_url: str,
10498
enhance: typing.Callable,
105-
) -> InputOutputVideoMetadata:
106-
input_file = image or video
107-
assert input_file, "Please provide an image or video input"
99+
) -> gooey_gpu.InputOutputVideoMetadata:
100+
input_url = image or video
101+
assert input_url, "Please provide an image or video input"
108102

109103
with TemporaryDirectory() as save_dir:
110-
input_path, _ = urlretrieve(
111-
input_file,
112-
os.path.join(save_dir, "input" + os.path.splitext(input_file)[1]),
113-
)
104+
input_path = os.path.join(save_dir, "input" + os.path.splitext(input_url)[1])
105+
gooey_gpu.download_file_to_path(url=input_url, path=input_path)
114106
output_path = os.path.join(save_dir, "out.mp4")
115107

116-
response = InputOutputVideoMetadata(
117-
input=ffprobe_video(input_path), output=VideoMetadata()
108+
response = gooey_gpu.InputOutputVideoMetadata(
109+
input=gooey_gpu.ffprobe_video(input_path), output=gooey_gpu.VideoMetadata()
118110
)
119111
# ensure max input/output is 1080p
120112
input_pixels = response.input.width * response.input.height
121113
if input_pixels > MAX_RES:
122-
raise ValueError(
114+
raise gooey_gpu.UserError(
123115
"Input video resolution exceeds 1920x1080. Please downscale to 1080p."
124116
)
125117
max_scale = math.sqrt(MAX_RES / input_pixels)
@@ -128,7 +120,7 @@ def run_enhancer(
128120

129121
ffproc = None
130122
for frame in tqdm(
131-
ffmpeg_read_input_frames(
123+
gooey_gpu.ffmpeg_read_input_frames(
132124
width=response.input.width,
133125
height=response.input.height,
134126
input_path=input_path,
@@ -152,7 +144,7 @@ def run_enhancer(
152144
response.output.width = restored_img.shape[1]
153145
response.output.height = restored_img.shape[0]
154146
response.output.fps = response.input.fps or 24
155-
ffproc = ffmpeg_get_writer_proc(
147+
ffproc = gooey_gpu.ffmpeg_get_writer_proc(
156148
width=response.output.width,
157149
height=response.output.height,
158150
fps=response.output.fps,
@@ -214,7 +206,7 @@ def load_gfpgan_model(model_id: str) -> "GFPGANer":
214206

215207
print(f"loading {model_id} via {url}...")
216208
model_path = os.path.join(gfpgan_checkpoint_dir, os.path.basename(url))
217-
gooey_gpu.download_file_cached(url=url, path=model_path)
209+
gooey_gpu.download_file_to_path(url=url, path=model_path, cached=True)
218210

219211
return GFPGANer(
220212
model_path=model_path,
@@ -282,7 +274,7 @@ def load_esrgan_model(model_id: str) -> "RealESRGANer":
282274
for url in file_url:
283275
print(f"loading {model_id} via {url}...")
284276
model_path = os.path.join(gooey_gpu.CHECKPOINTS_DIR, os.path.basename(url))
285-
gooey_gpu.download_file_cached(url=url, path=model_path)
277+
gooey_gpu.download_file_to_path(url=url, path=model_path, cached=True)
286278
assert model_path, f"Model {model_id} not found"
287279

288280
return RealESRGANer(

retro/nvidia_nemo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def load_model(model_url: str):
3030
# get cached model path
3131
model_path = os.path.join(gooey_gpu.CHECKPOINTS_DIR, os.path.basename(model_url))
3232
# if not cached, download again
33-
gooey_gpu.download_file_cached(url=model_url, path=model_path)
33+
gooey_gpu.download_file_to_path(url=model_url, path=model_path, cached=True)
3434
# load model
3535
return nemo_asr.models.ASRModel.restore_from(model_path)
3636

0 commit comments

Comments
 (0)