55import typing
66from functools import lru_cache
77from tempfile import TemporaryDirectory
8- from urllib .request import urlretrieve
8+ import requests
9+ import random
910
1011import cv2
1112import numpy as np
2930
3031from src .facerender .animate import AnimateFromCoeff
3132from src .generate_batch import get_data
33+ from src import generate_batch as gb
3234from src .generate_facerender_batch import get_facerender_data
3335from src .test_audio2coeff import Audio2Coeff
3436from src .utils .init_path import init_path
3941MAX_RES = 1920 * 1080
4042
4143
44+ # the original sadtalker function does not work for short audio
def fixed_generate_blink_seq_randomly(num_frames):
    """Build a random eye-blink ratio track for a clip of ``num_frames`` frames.

    The result is a ``(num_frames, 1)`` float array that is zero everywhere
    except for 5-frame blink pulses shaped ``0.5, 0.9, 1.0, 0.9, 0.5``.

    Clips where ``int(num_frames / 2) <= 11`` get no blinks at all and the
    all-zero track is returned immediately; this also keeps the range of
    candidate gaps below from ever being empty.

    Args:
        num_frames: Number of frames in the clip.

    Returns:
        ``numpy.ndarray`` of shape ``(num_frames, 1)`` holding the blink
        ratio for every frame.
    """
    blink_track = np.zeros((num_frames, 1))
    half_length = int(num_frames / 2)
    if half_length <= 11:
        return blink_track

    pulse = (0.5, 0.9, 1.0, 0.9, 0.5)
    # Candidate gaps (in frames) between the end of one blink and the start
    # of the next; the bounds never change, so build the range once.
    gap_choices = range(min(10, num_frames), min(half_length, 70))
    final_index = num_frames - 1

    cursor = 0
    while 0 <= cursor < num_frames:
        pulse_start = cursor + random.choice(gap_choices)
        pulse_stop = pulse_start + 5
        if pulse_stop > final_index:
            # The next pulse would not fit before the last frame.
            break
        blink_track[pulse_start:pulse_stop, 0] = pulse
        cursor = pulse_stop
    return blink_track
64+
65+
# so we patch in a fixed version
# NOTE(review): this rebinds the attribute on the imported `src.generate_batch`
# module object (`gb`). It only takes effect if callers inside that module look
# the function up through the module's globals at call time -- TODO confirm
# against src/generate_batch.py.
gb.generate_blink_seq_randomly = fixed_generate_blink_seq_randomly
68+
69+
def urlretrieve(url, filename, timeout=60):
    """Download ``url`` to ``filename`` using ``requests``.

    Same as urllib.urlretrieve but uses requests because urllib breaks on
    discord attachments. Does not support ``data:`` URLs and local files.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Local path the response body is written to (binary mode).
        timeout: Seconds passed to ``requests.get`` as its ``timeout``
            argument, so a stalled server cannot hang the caller forever.
            Pass ``None`` to wait indefinitely, as the previous version did.

    Returns:
        A ``(filename, None)`` tuple, so callers can unpack two values the
        same way they would with ``urllib.request.urlretrieve``; no headers
        object is provided.

    Raises:
        ValueError: If the response status is not OK.
        requests.exceptions.RequestException: Propagated from
            ``requests.get`` on connection errors or timeouts.
    """
    res = requests.get(url, timeout=timeout)
    if not res.ok:
        # The closing parenthesis was missing from this message.
        raise ValueError(
            f"Could not access user provided url: {url} ({res.status_code} {res.reason})"
        )
    with open(filename, "wb") as f:
        f.write(res.content)
    return filename, None
80+
81+
4282class SadtalkerPipeline (BaseModel ):
4383 upload_urls : typing .List [HttpUrl ] # upload url for the output video
4484 model_id : str
@@ -123,11 +163,11 @@ def sadtalker(
123163) -> InputOutputVideoMetadata :
124164 assert len (pipeline .upload_urls ) == 1 , "Expected exactly 1 upload url"
125165
126- face_mime_type = mimetypes .guess_type (inputs .source_image )[0 ] or ""
166+ face_mime_type = mimetypes .guess_type (inputs .source_image . split ( "?" )[ 0 ] )[0 ] or ""
127167 if not ("video/" in face_mime_type or "image/" in face_mime_type ):
128168 raise ValueError (f"Unsupported face format { face_mime_type !r} " )
129169
130- audio_mime_type = mimetypes .guess_type (inputs .driven_audio )[0 ] or ""
170+ audio_mime_type = mimetypes .guess_type (inputs .driven_audio . split ( "?" )[ 0 ] )[0 ] or ""
131171 if not ("audio/" in audio_mime_type or "video/" in audio_mime_type ):
132172 raise ValueError (f"Unsupported audio format { audio_mime_type !r} " )
133173
@@ -136,6 +176,26 @@ def sadtalker(
136176 inputs .source_image ,
137177 os .path .join (save_dir , "face" + os .path .splitext (inputs .source_image )[1 ]),
138178 )
179+ # convert image to jpg (to remove transparency) and make smaller than MAX_RES
180+ if face_mime_type .startswith ("image/" ):
181+ args = [
182+ "ffmpeg" ,
183+ "-y" ,
184+ "-i" ,
185+ input_path ,
186+ "-vf" ,
187+ "scale=w=1920:h=1080:force_original_aspect_ratio=decrease" ,
188+ "-q:v" ,
189+ "1" ,
190+ "-frames:v" ,
191+ "1" ,
192+ "-pix_fmt" ,
193+ "yuv420p" ,
194+ os .path .splitext (input_path )[0 ] + ".jpg" ,
195+ ]
196+ subprocess .check_output (args , encoding = "utf-8" )
197+ input_path = os .path .splitext (input_path )[0 ] + ".jpg"
198+
139199 audio_path , _ = urlretrieve (
140200 inputs .driven_audio ,
141201 os .path .join (save_dir , "audio" + os .path .splitext (inputs .driven_audio )[1 ]),
@@ -151,6 +211,24 @@ def sadtalker(
151211 print ("\t $ " + " " .join (args ))
152212 print (subprocess .check_output (args , encoding = "utf-8" ))
153213 audio_path = wav_audio_path
214+ # make sure audio is not 0 seconds
215+ audio_length = float (
216+ subprocess .check_output (
217+ [
218+ "ffprobe" ,
219+ "-v" ,
220+ "error" ,
221+ "-show_entries" ,
222+ "format=duration" ,
223+ "-of" ,
224+ "default=noprint_wrappers=1:nokey=1" ,
225+ audio_path ,
226+ ],
227+ encoding = "utf-8" ,
228+ )
229+ )
230+ if audio_length <= 0.1 :
231+ raise ValueError ("Audio is too short" )
154232
155233 response = InputOutputVideoMetadata (
156234 input = ffprobe_video (input_path ), output = VideoMetadata ()
@@ -167,13 +245,22 @@ def sadtalker(
167245 # crop image and extract 3dmm from image
168246 first_frame_dir = os .path .join (save_dir , "first_frame_dir" )
169247 os .makedirs (first_frame_dir , exist_ok = True )
170- first_coeff_path , crop_pic_path , crop_info = preprocess_model .generate (
171- input_path ,
172- first_frame_dir ,
173- pipeline .preprocess ,
174- source_image_flag = True ,
175- pic_size = pipeline .size ,
176- )
248+ try :
249+ first_coeff_path , crop_pic_path , crop_info = preprocess_model .generate (
250+ input_path ,
251+ first_frame_dir ,
252+ pipeline .preprocess ,
253+ source_image_flag = True ,
254+ pic_size = pipeline .size ,
255+ )
256+ except Exception as e :
257+ if face_mime_type .startswith ("video/" ):
258+ raise ValueError (
259+ "Could not identify the face in the video. Wav2Lip generally works better with videos."
260+ ) from e
261+ raise ValueError (
262+ "Could not identify the face in the image. Please try another image. Humanoid faces and solid backgrounds work best." ,
263+ ) from e
177264 if first_coeff_path is None :
178265 raise ValueError ("Can't get the coeffs of the input" )
179266
0 commit comments