Skip to content

Commit 98531a3

Browse files
SanderGidevxpy
authored and committed
handle top 8 errors
1 parent d50cee8 commit 98531a3

File tree

1 file changed

+97
-10
lines changed

1 file changed

+97
-10
lines changed

retro/sadtalker.py

Lines changed: 97 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import typing
66
from functools import lru_cache
77
from tempfile import TemporaryDirectory
8-
from urllib.request import urlretrieve
8+
import requests
9+
import random
910

1011
import cv2
1112
import numpy as np
@@ -29,6 +30,7 @@
2930

3031
from src.facerender.animate import AnimateFromCoeff
3132
from src.generate_batch import get_data
33+
from src import generate_batch as gb
3234
from src.generate_facerender_batch import get_facerender_data
3335
from src.test_audio2coeff import Audio2Coeff
3436
from src.utils.init_path import init_path
@@ -39,6 +41,44 @@
3941
MAX_RES = 1920 * 1080
4042

4143

44+
# the original sadtalker function does not work for short audio
45+
def fixed_generate_blink_seq_randomly(num_frames):
    """Build a per-frame blink-ratio sequence with randomly spaced blinks.

    Replacement for the original SadTalker blink generator, which does not
    work for short audio. Clips shorter than 24 frames get no blink at all
    instead of failing.

    Args:
        num_frames: Number of frames in the sequence (non-negative int).

    Returns:
        A ``numpy.ndarray`` of shape ``(num_frames, 1)`` that is zero
        everywhere except for 5-frame blink ramps
        ``[0.5, 0.9, 1.0, 0.9, 0.5]`` placed at random offsets.
    """
    blink = [0.5, 0.9, 1.0, 0.9, 0.5]
    ratio = np.zeros((num_frames, 1))

    # Too short to place a blink: the random gap below needs a non-trivial
    # range of at least [10, 12), i.e. num_frames >= 24.
    half = num_frames // 2
    if half <= 11:
        return ratio

    # Gap (in frames) between the end of one blink and the start of the next
    # is drawn from [10, max_gap).
    max_gap = min(half, 70)

    frame_id = 0
    while frame_id < num_frames:
        blink_start = frame_id + random.randrange(10, max_gap)
        blink_end = blink_start + len(blink)
        # Stop once a full blink no longer fits before the last frame.
        if blink_end > num_frames - 1:
            break
        ratio[blink_start:blink_end, 0] = blink
        frame_id = blink_end
    return ratio
64+
65+
66+
# so we patch in a fixed version
67+
# Rebinds the attribute on the `src.generate_batch` module object (imported
# above as `gb`) at import time of this file.
# NOTE(review): this only affects lookups that go through the module's own
# namespace at call time; any name already bound elsewhere via
# `from src.generate_batch import generate_blink_seq_randomly` would still
# point at the original function — confirm no such import exists.
gb.generate_blink_seq_randomly = fixed_generate_blink_seq_randomly
68+
69+
70+
def urlretrieve(url, filename, timeout=60):
    """Download ``url`` to ``filename`` using ``requests``.

    Same as urllib's ``urlretrieve`` but uses requests because urllib breaks
    on discord attachments. Does not support ``data:`` URLs and local files.

    Args:
        url: HTTP(S) URL to download.
        filename: Local path the response body is written to.
        timeout: Seconds to wait for the connection and for each read before
            giving up, passed straight to ``requests.get``. Without it a
            stalled server would block the caller forever.

    Returns:
        ``(filename, None)``, mirroring the ``(filename, headers)`` tuple
        shape of urllib's ``urlretrieve``.

    Raises:
        ValueError: If the server answers with a non-success status code.
    """
    # Stream the body so large media files are not held fully in memory.
    with requests.get(url, stream=True, timeout=timeout) as res:
        if not res.ok:
            raise ValueError(
                f"Could not access user provided url: {url} ({res.status_code} {res.reason})"
            )
        with open(filename, "wb") as f:
            for chunk in res.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    return filename, None
80+
81+
4282
class SadtalkerPipeline(BaseModel):
4383
upload_urls: typing.List[HttpUrl] # upload url for the output video
4484
model_id: str
@@ -123,11 +163,11 @@ def sadtalker(
123163
) -> InputOutputVideoMetadata:
124164
assert len(pipeline.upload_urls) == 1, "Expected exactly 1 upload url"
125165

126-
face_mime_type = mimetypes.guess_type(inputs.source_image)[0] or ""
166+
face_mime_type = mimetypes.guess_type(inputs.source_image.split("?")[0])[0] or ""
127167
if not ("video/" in face_mime_type or "image/" in face_mime_type):
128168
raise ValueError(f"Unsupported face format {face_mime_type!r}")
129169

130-
audio_mime_type = mimetypes.guess_type(inputs.driven_audio)[0] or ""
170+
audio_mime_type = mimetypes.guess_type(inputs.driven_audio.split("?")[0])[0] or ""
131171
if not ("audio/" in audio_mime_type or "video/" in audio_mime_type):
132172
raise ValueError(f"Unsupported audio format {audio_mime_type!r}")
133173

@@ -136,6 +176,26 @@ def sadtalker(
136176
inputs.source_image,
137177
os.path.join(save_dir, "face" + os.path.splitext(inputs.source_image)[1]),
138178
)
179+
# convert image to jpg (to remove transparency) and make smaller than MAX_RES
180+
if face_mime_type.startswith("image/"):
181+
args = [
182+
"ffmpeg",
183+
"-y",
184+
"-i",
185+
input_path,
186+
"-vf",
187+
"scale=w=1920:h=1080:force_original_aspect_ratio=decrease",
188+
"-q:v",
189+
"1",
190+
"-frames:v",
191+
"1",
192+
"-pix_fmt",
193+
"yuv420p",
194+
os.path.splitext(input_path)[0] + ".jpg",
195+
]
196+
subprocess.check_output(args, encoding="utf-8")
197+
input_path = os.path.splitext(input_path)[0] + ".jpg"
198+
139199
audio_path, _ = urlretrieve(
140200
inputs.driven_audio,
141201
os.path.join(save_dir, "audio" + os.path.splitext(inputs.driven_audio)[1]),
@@ -151,6 +211,24 @@ def sadtalker(
151211
print("\t$ " + " ".join(args))
152212
print(subprocess.check_output(args, encoding="utf-8"))
153213
audio_path = wav_audio_path
214+
# make sure audio is not 0 seconds
215+
audio_length = float(
216+
subprocess.check_output(
217+
[
218+
"ffprobe",
219+
"-v",
220+
"error",
221+
"-show_entries",
222+
"format=duration",
223+
"-of",
224+
"default=noprint_wrappers=1:nokey=1",
225+
audio_path,
226+
],
227+
encoding="utf-8",
228+
)
229+
)
230+
if audio_length <= 0.1:
231+
raise ValueError("Audio is too short")
154232

155233
response = InputOutputVideoMetadata(
156234
input=ffprobe_video(input_path), output=VideoMetadata()
@@ -167,13 +245,22 @@ def sadtalker(
167245
# crop image and extract 3dmm from image
168246
first_frame_dir = os.path.join(save_dir, "first_frame_dir")
169247
os.makedirs(first_frame_dir, exist_ok=True)
170-
first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(
171-
input_path,
172-
first_frame_dir,
173-
pipeline.preprocess,
174-
source_image_flag=True,
175-
pic_size=pipeline.size,
176-
)
248+
try:
249+
first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(
250+
input_path,
251+
first_frame_dir,
252+
pipeline.preprocess,
253+
source_image_flag=True,
254+
pic_size=pipeline.size,
255+
)
256+
except Exception as e:
257+
if face_mime_type.startswith("video/"):
258+
raise ValueError(
259+
"Could not identify the face in the video. Wav2Lip generally works better with videos."
260+
) from e
261+
raise ValueError(
262+
"Could not identify the face in the image. Please try another image. Humanoid faces and solid backgrounds work best.",
263+
) from e
177264
if first_coeff_path is None:
178265
raise ValueError("Can't get the coeffs of the input")
179266

0 commit comments

Comments
 (0)