diff --git a/README.md b/README.md
index 49e76c44..c81231f2 100644
--- a/README.md
+++ b/README.md
@@ -17,9 +17,85 @@ Wenjiang Zhou
 Lyra Lab, Tencent Music Entertainment
 
 **[github](https://github.com/TMElyralab/MuseTalk)** **[huggingface](https://huggingface.co/TMElyralab/MuseTalk)** **[space](https://huggingface.co/spaces/TMElyralab/MuseTalk)** **[Technical report](https://arxiv.org/abs/2410.10122)**
+**[colab](MuseTalkV15.ipynb)**
 
 We introduce `MuseTalk`, a **real-time high quality** lip-syncing model (30fps+ on an NVIDIA Tesla V100). MuseTalk can be applied with input videos, e.g., generated by [MuseV](https://github.com/TMElyralab/MuseV), as a complete virtual human solution.
+
+## πŸš€ What's New in This Fork
+
+- **Dockerized MuseTalk Service:** Fully packaged for easy deployment as an API service or in your own workflows.
+- **FastAPI Integration:** Added `fastapi_service.py` for running MuseTalk inference through a modern REST API.
+- **Automated Model Downloads:** The `download_models.py` script downloads all required weights from Hugging Face, Google Drive, and direct URLs (including S3FD).
+- **Self-Installing Dependencies:** The model downloader installs missing Python packages automatically, for a cleaner first-run experience.
+- **Submodule Friendly:** Designed to work as a submodule in larger projects. Clone, init, and run.
+- **Improved Repository Structure:** All models are organized under `/models/` for clear management and reproducibility.
+- **Workflow Automation:** Minimal manual steps: run one setup script, then launch the API or Docker container.
+
+***
+
+### πŸ› οΈ How to Run This Fork
+
+**Step 1. Prepare the environment**
+Clone the repo and initialize its submodules:
+```bash
+git clone --recurse-submodules https://github.com/rafipatel/MuseTalk.git
+```
+
+**Step 2. If not using Docker, follow the installation steps in [Installation](#installation) up to the [download weights section](#setup-ffmpeg)**
+
+**Step 3. Download models and weights with the new script in this fork (the original `download_weights.sh` throws a huggingface-cli error)**
+Enter the MuseTalk folder and run the setup script:
+```bash
+cd MuseTalk
+python download_models.py
+```
+- This will:
+  - Install any missing Python dependencies
+  - Download all required weights/checkpoints into `/models/`
+
+**Step 4. Run the MuseTalk API**
+You can launch the FastAPI server directly:
+```bash
+uvicorn fastapi_service:app --host 0.0.0.0 --port 8000
+```
+Or run inference directly via `scripts.inference` (edit `test.yaml` to point to your audio and video):
+```bash
+python -m scripts.inference --inference_config configs/inference/test.yaml --result_dir results/test --unet_model_path models/musetalkV15/unet.pth --unet_config models/musetalkV15/musetalk.json --version v15 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared/bin
+```
+
+Or (recommended for production) launch via Docker, after downloading the model files and mounting the volumes to the container as shown below:
+```bash
+docker build -t musetalk .
+
+docker run \
+  -p 8000:8000 \
+  -v $(pwd)/models:/app/MuseTalk/models \
+  -v $(pwd)/results:/app/MuseTalk/results \
+  -v $(pwd)/data:/app/MuseTalk/data \
+  -v $(pwd)/configs:/app/MuseTalk/configs \
+  musetalk
+```
+
+**Step 5. API Usage**
+Access the API at `http://localhost:8000` and POST your inference jobs (an example `curl` request is shown below, after the optional notes).
+
+***
+
+**Optional:**
+- Edit `download_models.py` to download extra models if needed.
+- Update requirements as new features are added.
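+
+**Example request**
+A minimal `curl` sketch against the endpoint defined in `fastapi_service.py`. The form fields are `video`, `audio`, and `inference_params` (a JSON string); any field omitted from `inference_params` falls back to the defaults in `InferenceConfig`. The sample assets are the ones referenced in `test.py`:
+```bash
+curl -X POST http://localhost:8000/inference \
+  -F "video=@data/video/sun.mp4" \
+  -F "audio=@data/audio/eng.wav" \
+  -F 'inference_params={"version": "v15", "batch_size": 8}' \
+  --output result.mp4
+```
+On success the response body is the rendered MP4; on failure the service returns a JSON error object. `test.py` in this fork sends the same request from Python.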
+ +*** + +## πŸ“ Note +- For GPU support, make sure your Docker and system configuration are compatible. +- All downloads are script-automated for plug-and-play β€” no manual model shopping or setup required! +- For more details, see the commented sections in each main script. + + + ## πŸ”₯ Updates We're excited to unveil MuseTalk 1.5. This version **(1)** integrates training with perceptual loss, GAN loss, and sync loss, significantly boosting its overall performance. **(2)** We've implemented a two-stage training strategy and a spatio-temporal data sampling approach to strike a balance between visual quality and lip-sync accuracy. diff --git a/dockerfile b/dockerfile new file mode 100644 index 00000000..1f1b24b3 --- /dev/null +++ b/dockerfile @@ -0,0 +1,80 @@ +FROM python:3.10-slim-bullseye + +ENV PYTHONUNBUFFERED=1 +ENV PYTORCH_ENABLE_MPS_FALLBACK=1 + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + git \ + wget \ + ffmpeg \ + libsndfile1 \ + libgl1-mesa-glx \ + libglib2.0-0 \ + build-essential \ + curl \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Install PyTorch for CPU/MPS (Mac) +RUN pip3 install --no-cache-dir \ + torch==2.0.1 \ + torchvision==0.15.2 \ + torchaudio==2.0.2 + +# Clone MuseTalk +# RUN git clone https://github.com/TMElyralab/MuseTalk.git /app/MuseTalk +RUN git clone https://github.com/rafipatel/MuseTalk.git /app/MuseTalk + +WORKDIR /app/MuseTalk + +# Install requirements +RUN pip3 install --no-cache-dir -r requirements.txt +# RUN pip3 install -r requirements.txt || true + +# Install OpenMMLab packages (CPU version) +RUN pip3 install --no-cache-dir -U openmim && \ + mim install mmengine && \ + pip3 install mmcv==2.0.1 && \ + mim install "mmdet==3.1.0" && \ + mim install "mmpose==1.1.0" + +# Download model weights +# RUN python3 -m pip install huggingface_hub && \ +# python3 -c "from huggingface_hub import snapshot_download; \ +# snapshot_download(repo_id='TMElyralab/MuseTalk', local_dir='./models', allow_patterns=['models/musetalkV15/*'])" || true + +# Download additional model files + + +# RUN mkdir -p models/face-parse-bisent && \ +# pip3 install gdown && \ +# gdown --id 154JgKpzCPW82qINcVieuPH3fZ2e0P812 -O models/face-parse-bisent/79999_iter.pth && \ +# curl -L https://download.pytorch.org/models/resnet18-5c106cde.pth -o models/face-parse-bisent/resnet18-5c106cde.pth + + +# # Set working directory +# WORKDIR /app/MuseTalk + +# Create entrypoint script +# RUN echo '#!/bin/bash\n\ +# python -m scripts.inference \\\n\ +# --inference_config ${INFERENCE_CONFIG:-configs/inference/test.yaml} \\\n\ +# --result_dir ${RESULT_DIR:-results/test} \\\n\ +# --unet_model_path ${UNET_MODEL_PATH:-models/musetalkV15/unet.pth} \\\n\ +# --unet_config ${UNET_CONFIG:-models/musetalkV15/musetalk.json} \\\n\ +# --version ${VERSION:-v15} \\\n\ +# --ffmpeg_path ${FFMPEG_PATH:-/usr/bin/ffmpeg}' > /app/run_inference.sh && \ +# chmod +x /app/run_inference.sh + +# RUN chmod +x /app/MuseTalk/download_weights.sh +# RUN /app/MuseTalk/download_weights.sh + +# COPY inference.sh /app/MuseTalk/inference.sh + +# RUN chmod +x /app/MuseTalk/inference.sh +# CMD ["/app/MuseTalk/inference.sh", "v1.5","normal"] +# CMD ["uvicorn", "fastapi_service:app","--host", "0.0.0.0" ,"--port", "8000"] +CMD ["python3", "-m", "uvicorn", "fastapi_service:app", "--host", "0.0.0.0" ,"--port", "8000"] +# CMD ["/app/run_inference.sh"] diff --git a/download_models.py b/download_models.py new file mode 100644 index 00000000..a13f58a6 --- /dev/null +++ b/download_models.py @@ -0,0 +1,188 @@ 
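+"""Download every model checkpoint MuseTalk needs into ./models/.
+
+The script (a) installs the download helpers it is missing (huggingface_hub,
+gdown, requests), (b) creates the models/ subdirectories, and (c) fetches the
+MuseTalk V1.0/V1.5, sd-vae, Whisper-tiny, DWPose and SyncNet weights from
+Hugging Face, the BiSeNet face-parsing model from Google Drive, ResNet18 from
+download.pytorch.org, and the S3FD face detector (into
+~/.cache/torch/hub/checkpoints).
+"""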
+import os +import sys +import importlib +import subprocess + +# List any packages needed for downloading +REQUIRED_PACKAGES = [ + "huggingface_hub", + "gdown", + "requests" +] + +PYTHON_EXEC = sys.executable + +# Install missing packages BEFORE importing +for pkg in REQUIRED_PACKAGES: + try: + importlib.import_module(pkg.replace("-", "_")) + except ImportError: + print(f"Installing {pkg} ...") + subprocess.run([PYTHON_EXEC, "-m", "pip", "install", pkg]) + + +from huggingface_hub import hf_hub_download +import sys +import subprocess +import importlib + +# --- Configuration --- +CHECKPOINTS_DIR = "models" +HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co") # Use mirror if set + +# --- Directory Setup --- +DIRS = [ + "musetalkV15", "syncnet", "dwpose", + "face-parse-bisent", "sd-vae", "whisper", "musetalk" # Ensure 'musetalk' is here if V1.0 is needed +] + +for d in DIRS: + os.makedirs(os.path.join(CHECKPOINTS_DIR, d), exist_ok=True) +print(f"βœ… Created base directory: {CHECKPOINTS_DIR} and subdirectories.") + +# --- Hugging Face Downloads --- + +def download_hf_files(repo_id, filenames, subdir="", has_subpath=False): + """ + Downloads a list of files from a Hugging Face repo. + + If has_subpath is True (e.g., MuseTalk), files are downloaded relative to CHECKPOINTS_DIR. + If has_subpath is False (e.g., Whisper), files are downloaded directly into CHECKPOINTS_DIR/subdir. + """ + target_local_dir = os.path.join(CHECKPOINTS_DIR, subdir) + + # If the filename contains the directory structure (e.g., "repo_name/file.bin"), + # we need to set local_dir to CHECKPOINTS_DIR to preserve the path. + # Otherwise, we set local_dir to the final destination (target_local_dir). + final_local_dir = CHECKPOINTS_DIR if has_subpath else target_local_dir + + for filename in filenames: + print(f"Downloading {filename} from {repo_id} to {target_local_dir}...") + + # Use hf_hub_download. The output path handling is based on `has_subpath`. + hf_hub_download( + repo_id=repo_id, + filename=filename, + local_dir=final_local_dir, + endpoint=HF_ENDPOINT + ) + print(f"βœ… Finished downloading files for {repo_id} into {subdir}/.") + + +# 1. MuseTalk V1.0 & V1.5 Weights (Uses subpaths in filenames) +# NOTE: The repo files are structured like "musetalk/..." and "musetalkV15/..." +# Setting local_dir=CHECKPOINTS_DIR ensures this internal structure is preserved under "models/" + +# V1.0 Files (Target: models/musetalk) +download_hf_files( + repo_id="TMElyralab/MuseTalk", + filenames=[ + "musetalk/musetalk.json", + "musetalk/pytorch_model.bin" + ], + subdir="musetalk", + has_subpath=True # Filenames contain the subdir path +) +# V1.5 Files (Target: models/musetalkV15) +download_hf_files( + repo_id="TMElyralab/MuseTalk", + filenames=[ + "musetalkV15/musetalk.json", + "musetalkV15/unet.pth" + ], + subdir="musetalkV15", + has_subpath=True # Filenames contain the subdir path +) + +# 2. SD VAE Weights (No subpaths in filenames) +# Target: models/sd-vae/ +download_hf_files( + repo_id="stabilityai/sd-vae-ft-mse", + filenames=[ + "config.json", + "diffusion_pytorch_model.bin", + "diffusion_pytorch_model.safetensors" + ], + subdir="sd-vae", + has_subpath=False +) + +# 3. Whisper Weights (No subpaths in filenames) +# FIX: This now downloads directly into models/whisper/ +# Target: models/whisper/ +download_hf_files( + repo_id="openai/whisper-tiny", + filenames=[ + "config.json", + "pytorch_model.bin", + "preprocessor_config.json" + ], + subdir="whisper", + has_subpath=False +) + +# 4. 
DWPose Weights (No subpaths in filenames) +# Target: models/dwpose/ +download_hf_files( + repo_id="yzd-v/DWPose", + filenames=["dw-ll_ucoco_384.pth"], + subdir="dwpose", + has_subpath=False +) + +# 5. SyncNet Weights (No subpaths in filenames) +# Target: models/syncnet/ +download_hf_files( + repo_id="ByteDance/LatentSync", + filenames=["latentsync_syncnet.pt"], + subdir="syncnet", + has_subpath=False +) + +print("--- Hugging Face downloads complete. ---") + + + +# Download BiSeNet Face Parse Model file (from Google Drive) +try: + import gdown +except ImportError: + subprocess.run(['pip', 'install', 'gdown']) + import gdown +gdown.download( + 'https://drive.google.com/uc?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812', + os.path.join(CHECKPOINTS_DIR, "face-parse-bisent", "79999_iter.pth"), + quiet=False +) + + + + +# Download resnet18 model +import requests +url = "https://download.pytorch.org/models/resnet18-5c106cde.pth" +output_path = os.path.join(CHECKPOINTS_DIR, "face-parse-bisent", "resnet18-5c106cde.pth") +response = requests.get(url, stream=True) +with open(output_path, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) +print(f"βœ… Downloaded {url} to {output_path}") + + + +s3fd_url = "https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth" +s3fd_dest_dir = os.path.expanduser("~/.cache/torch/hub/checkpoints") +os.makedirs(s3fd_dest_dir, exist_ok=True) +s3fd_dest_path = os.path.join(s3fd_dest_dir, "s3fd-619a316812.pth") + +if not os.path.exists(s3fd_dest_path) or os.path.getsize(s3fd_dest_path) < 85_000_000: # ~85MB expected + print(f"Downloading S3FD weights from {s3fd_url} ...") + response = requests.get(s3fd_url, stream=True) + with open(s3fd_dest_path, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + print(f"βœ… Downloaded S3FD model to {s3fd_dest_path}") +else: + print(f"βœ… S3FD weights already present: {s3fd_dest_path}") + +print("--- All model downloads complete. 
---") \ No newline at end of file diff --git a/fastapi_service.py b/fastapi_service.py new file mode 100644 index 00000000..7bd5eef9 --- /dev/null +++ b/fastapi_service.py @@ -0,0 +1,300 @@ +import os +import cv2 +import math +import copy +import torch +import glob +import shutil +import pickle +import numpy as np +import subprocess +import json +from fastapi import FastAPI, UploadFile, File, Form +from fastapi.responses import FileResponse, JSONResponse +from pydantic import BaseModel +from typing import Optional + +from musetalk.utils.blending import get_image +from musetalk.utils.face_parsing import FaceParsing +from musetalk.utils.audio_processor import AudioProcessor +from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model +from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder +from omegaconf import OmegaConf +from transformers import WhisperModel + +app = FastAPI(title="MuseTalk FastAPI Service") + + +def fast_check_ffmpeg(ffmpeg_path=None): + try: + subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True) + return True + except: + if ffmpeg_path: + os.environ["PATH"] = f"{ffmpeg_path}:{os.environ['PATH']}" + try: + subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True) + return True + except: + return False + return False + + +class InferenceConfig(BaseModel): + ffmpeg_path: str = "./ffmpeg-4.4-amd64-static/" + gpu_id: int = 0 + vae_type: str = "sd-vae" + unet_config: str = "./models/musetalkV15/musetalk.json" + unet_model_path: str = "./models/musetalkV15/unet.pth" + whisper_dir: str = "./models/whisper" + inference_config: str = "configs/inference/test.yaml" + bbox_shift: int = 0 + result_dir: str = './results' + extra_margin: int = 10 + fps: int = 25 + audio_padding_length_left: int = 2 + audio_padding_length_right: int = 2 + batch_size: int = 8 + output_vid_name: Optional[str] = None + use_saved_coord: bool = False + saved_coord: bool = False + use_float16: bool = False + parsing_mode: str = 'jaw' + left_cheek_width: int = 90 + right_cheek_width: int = 90 + version: str = "v15" + video_path: Optional[str] = None # <-- add this! 
+ audio_path: Optional[str] = None + + +# @app.post("/inference") +# async def run_inference( +# video: UploadFile = File(...), +# audio: UploadFile = File(...), +# inference_params: InferenceConfig = Form(...), +# ): + +@app.post("/inference") +async def run_inference( + video: UploadFile = File(...), + audio: UploadFile = File(...), + inference_params: str = Form("{}") +): + import json + params_dict = json.loads(inference_params) + args = InferenceConfig(**params_dict) # <-- make sure InferenceConfig is imported + + # Make temp paths for uploads + video_path = f"temp_{video.filename}" + audio_path = f"temp_{audio.filename}" + with open(video_path, "wb") as f: + f.write(await video.read()) + with open(audio_path, "wb") as f: + f.write(await audio.read()) + + # Attach the file paths to the config object + args.video_path = video_path + args.audio_path = audio_path + + # Proceed with your logic using args + result_path, msg = await process_inference(args) + # Clean up temp files + os.remove(video_path) + os.remove(audio_path) + if result_path: + return FileResponse(path=result_path, filename=os.path.basename(result_path)) + else: + return JSONResponse({"error": msg or "Failed to process."}, status_code=500) + + +async def process_inference(args): + try: + print("Entering process_inference...") + if not fast_check_ffmpeg(args.ffmpeg_path): + print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed") + return None, "ffmpeg missing" + device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu") + vae, unet, pe = load_all_model( + unet_model_path=args.unet_model_path, + vae_type=args.vae_type, + unet_config=args.unet_config, + device=device + ) + timesteps = torch.tensor([0], device=device) + if args.use_float16: + pe = pe.half() + vae.vae = vae.vae.half() + unet.model = unet.model.half() + pe = pe.to(device) + vae.vae = vae.vae.to(device) + unet.model = unet.model.to(device) + audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir) + weight_dtype = unet.model.dtype + whisper = WhisperModel.from_pretrained(args.whisper_dir) + whisper = whisper.to(device=device, dtype=weight_dtype).eval() + whisper.requires_grad_(False) + # Face parser + fp = FaceParsing( + left_cheek_width=args.left_cheek_width, + right_cheek_width=args.right_cheek_width + ) if args.version == "v15" else FaceParsing() + + inference_config = OmegaConf.load(args.inference_config) + out_paths = [] + print("Loaded inference config:", inference_config) + + for task_id in inference_config: + try: + print(f"Processing task: {task_id}") + # 1. Get config for this task + task = inference_config[task_id] + video_path = args.video_path + audio_path = args.audio_path + output_vid_name = task.get("result_name", None) + + # 2. Set other params + bbox_shift = 0 if args.version == "v15" else task.get("bbox_shift", args.bbox_shift) + input_basename = os.path.basename(video_path).split('.')[0] + audio_basename = os.path.basename(audio_path).split('.')[0] + output_basename = f"{input_basename}_{audio_basename}" + temp_dir = os.path.join(args.result_dir, f"{args.version}") + os.makedirs(temp_dir, exist_ok=True) + result_img_save_path = os.path.join(temp_dir, output_basename) + crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename + ".pkl") + os.makedirs(result_img_save_path, exist_ok=True) + save_dir_full = os.path.join(temp_dir, input_basename) + + # 3. 
Extract frames from video + if get_file_type(video_path) == "video": + os.makedirs(save_dir_full, exist_ok=True) + cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png" + os.system(cmd) + input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]'))) + fps = get_video_fps(video_path) + elif get_file_type(video_path) == "image": + input_img_list = [video_path] + fps = args.fps + elif os.path.isdir(video_path): + input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]')) + input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0])) + fps = args.fps + else: + raise ValueError(f"{video_path} should be a video file, an image file or a directory of images") + + # 4. Extract audio features + whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path) + whisper_chunks = audio_processor.get_whisper_chunk( + whisper_input_features, device, weight_dtype, whisper, librosa_length, + fps=fps, + audio_padding_length_left=args.audio_padding_length_left, + audio_padding_length_right=args.audio_padding_length_right, + ) + + # 5. Preprocess input images + if os.path.exists(crop_coord_save_path) and args.use_saved_coord: + print("Using saved coordinates") + with open(crop_coord_save_path, 'rb') as f: + coord_list = pickle.load(f) + frame_list = read_imgs(input_img_list) + else: + print("Extracting landmarks... time-consuming operation") + coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift) + with open(crop_coord_save_path, 'wb') as f: + pickle.dump(coord_list, f) + + print(f"Number of frames: {len(frame_list)}") + + input_latent_list = [] + for bbox, frame in zip(coord_list, frame_list): + if bbox == coord_placeholder: + continue + x1, y1, x2, y2 = bbox + if args.version == "v15": + y2 = y2 + args.extra_margin + y2 = min(y2, frame.shape[0]) + crop_frame = frame[y1:y2, x1:x2] + crop_frame = cv2.resize(crop_frame, (256,256), interpolation=cv2.INTER_LANCZOS4) + latents = vae.get_latents_for_unet(crop_frame) + input_latent_list.append(latents) + + frame_list_cycle = frame_list + frame_list[::-1] + coord_list_cycle = coord_list + coord_list[::-1] + input_latent_list_cycle = input_latent_list + input_latent_list[::-1] + + # 6. 
Batch inference + print("Starting inference") + video_num = len(whisper_chunks) + batch_size = args.batch_size + gen = datagen( + whisper_chunks=whisper_chunks, + vae_encode_latents=input_latent_list_cycle, + batch_size=batch_size, + delay_frame=0, + device=device, + ) + res_frame_list = [] + total = int(np.ceil(float(video_num) / batch_size)) + from tqdm import tqdm + for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)): + audio_feature_batch = pe(whisper_batch) + latent_batch = latent_batch.to(dtype=unet.model.dtype) + pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample + recon = vae.decode_latents(pred_latents) + for res_frame in recon: + res_frame_list.append(res_frame) + + # Pad generated images + print("Padding generated images to original video size") + for i, res_frame in enumerate(tqdm(res_frame_list)): + bbox = coord_list_cycle[i%(len(coord_list_cycle))] + ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))]) + x1, y1, x2, y2 = bbox + if args.version == "v15": + y2 = y2 + args.extra_margin + y2 = min(y2, frame.shape[0]) + try: + res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1)) + except: + continue + if args.version == "v15": + combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp) + else: + combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp) + cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame) + + # 7. Save prediction results + temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4" + cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}" + print("Video generation command:", cmd_img2video) + os.system(cmd_img2video) + + if output_vid_name is None: + output_vid_name_full = os.path.join(temp_dir, output_basename + ".mp4") + else: + output_vid_name_full = os.path.join(temp_dir, output_vid_name) + cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name_full}" + print("Audio combination command:", cmd_combine_audio) + os.system(cmd_combine_audio) + + # Clean up temporary files for this task + shutil.rmtree(result_img_save_path) + os.remove(temp_vid_path) + shutil.rmtree(save_dir_full) + if not args.saved_coord: + os.remove(crop_coord_save_path) + + print(f"Results saved to {output_vid_name_full}") + + # APPEND OUTPUT + out_paths.append(output_vid_name_full) + except Exception as e_task: + print(f"Error in task {task_id}:", str(e_task)) + import traceback; traceback.print_exc() + + # Return result (first output video path) + return out_paths[0] if out_paths else None, "No output video produced!" 
+ except Exception as e: + import traceback + traceback.print_exc() + print("Exception:", str(e)) + return None, str(e) diff --git a/requirements.txt b/requirements.txt index e87aa41d..3736348c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,3 +18,8 @@ imageio[ffmpeg] omegaconf ffmpeg-python moviepy + +fastapi +uvicorn +pydantic +python-multipart \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 00000000..19b13994 --- /dev/null +++ b/test.py @@ -0,0 +1,55 @@ +import requests +import json + +# URL of your FastAPI service +url = "http://localhost:8000/inference" + +# Paths to test video/audio files +video_path = "/Users/rafa/MscAi/curify/MuseTalk/data/video/sun.mp4" +audio_path = "/Users/rafa/MscAi/curify/MuseTalk/data/audio/eng.wav" + +# Inference parameters (use defaults or specify your values as needed) +inference_params = { + "ffmpeg_path": "./ffmpeg-4.4-amd64-static/", + "gpu_id": 0, + "vae_type": "sd-vae", + "unet_config": "./models/musetalkV15/musetalk.json", + "unet_model_path": "./models/musetalkV15/unet.pth", + "whisper_dir": "./models/whisper", + "inference_config": "configs/inference/test.yaml", + "bbox_shift": 0, + "result_dir": "./results", + "extra_margin": 10, + "fps": 25, + "audio_padding_length_left": 2, + "audio_padding_length_right": 2, + "batch_size": 8, + "output_vid_name": None, + "use_saved_coord": False, + "saved_coord": False, + "use_float16": False, + "parsing_mode": "jaw", + "left_cheek_width": 90, + "right_cheek_width": 90, + "version": "v15" +} +# Many can be omitted if you use defaults! + +files = { + "video": open(video_path, "rb"), + "audio": open(audio_path, "rb"), +} +# Must send JSON string for inference_params +data = { + "inference_params": json.dumps(inference_params) +} + +response = requests.post(url, files=files, data=data) + +# Save result or print error +if response.ok: + with open("result.mp4", "wb") as f: + f.write(response.content) + print("Success! Video saved as result.mp4") +else: + print(response.status_code, response.text)
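+
+# Note (sketch): `inference_params` defaults to "{}" on the server, so the dict
+# above only needs the fields you want to override; everything else falls back
+# to the defaults in `InferenceConfig` (fastapi_service.py). For example:
+#
+#   data = {"inference_params": json.dumps({"use_float16": True})}
+#
+# If you send a second request with the same `files` dict, re-open the file
+# handles first, since requests leaves them at end-of-file after the upload.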