Skip to content

Commit ad364d2

Browse files
committed
support audio & normalize API
1 parent 3d69b74 commit ad364d2

File tree

3 files changed

+208
-6
lines changed

3 files changed

+208
-6
lines changed

src/rest/rest.py

Lines changed: 146 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from fastapi import FastAPI, APIRouter, HTTPException
21
from http import HTTPStatus
32

3+
from fastapi import FastAPI, APIRouter, HTTPException
4+
45
from src.api.api import (
56
Namespace,
67
CreateNamespaceResponse,
@@ -13,16 +14,17 @@
1314
DeleteFilesRequest,
1415
ListDirectoryResponse,
1516
)
16-
from src.service.audio import AudioService
17+
from src.logger import logger
18+
from src.service.audio import AudioUVR5Params, AudioSlicerParams, AudioASRParams, AudioService, AudioDenoiseParams, AudioRefinementSubmitParams, AudioRefinementDeleteParams
1719
from src.service.file import FileService
1820
from src.service.namespace import NamespaceService
21+
from src.service.normalize import NormalizeService, NormalizeParams
22+
from src.service.session import SessionManager, session_guard
1923
from src.service.train import TrainGPTService, TrainSovitsService
2024
from src.service.voice import VoiceCloneService
21-
from src.service.session import SessionManager, session_guard
2225
from src.train.gpt import GPTTrainParams
2326
from src.train.sovits import SovitsTrainParams
2427
from src.utils.response import EaseVoiceResponse
25-
from src.logger import logger
2628

2729

2830
class NamespaceAPI:
@@ -287,7 +289,7 @@ def _register_routes(self):
287289
self.router.post("/train/gpt")(self.train_gpt)
288290
self.router.post("/train/sovits")(self.train_sovits)
289291

290-
def train_gpt(self, params: GPTTrainParams):
292+
async def train_gpt(self, params: GPTTrainParams):
291293
result = self._do_train_gpt(params)
292294

293295
# session_guard wrapper returns a dict
@@ -298,7 +300,7 @@ def train_gpt(self, params: GPTTrainParams):
298300

299301
async def train_sovits(self, params: SovitsTrainParams):
300302
result = self._do_train_sovits(params)
301-
303+
302304
# session_guard wrapper returns a dict
303305
if isinstance(result, EaseVoiceResponse):
304306
return result
@@ -316,6 +318,138 @@ def _do_train_sovits(self, params: SovitsTrainParams):
316318
return service.train()
317319

318320

321+
class NormalizeAPI:
322+
def __init__(self):
323+
self.router = APIRouter()
324+
self._register_routes()
325+
326+
def _register_routes(self):
327+
self.router.post("/normalize")(self.normalize)
328+
329+
async def normalize(self, request: NormalizeParams):
330+
result = self._do_normalize(request)
331+
332+
# session_guard wrapper return a dict
333+
if isinstance(result, EaseVoiceResponse):
334+
return result
335+
logger.error(f"failed to normalize: {result}")
336+
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=result)
337+
338+
@session_guard("Normalize")
339+
def _do_normalize(self, params: NormalizeParams):
340+
service = NormalizeService(params.processing_path)
341+
return service.normalize()
342+
343+
344+
class AudioAPI:
345+
def __init__(self):
346+
self.router = APIRouter()
347+
self._register_routes()
348+
349+
def _register_routes(self):
350+
self.router.post("/audio/uvr5")(self.audio_uvr5)
351+
self.router.post("/audio/slicer")(self.audio_slicer)
352+
self.router.post("/audio/denoise")(self.audio_denoise)
353+
self.router.post("/audio/asr")(self.audio_asr)
354+
self.router.get("/audio/refinement")(self.list_audio_refinement)
355+
self.router.post("/audio/refinement")(self.update_audio_refinement)
356+
self.router.delete("/audio/refinement")(self.delete_audio_refinement)
357+
358+
async def audio_uvr5(self, request: AudioUVR5Params):
359+
result = self._do_audio_uvr5(request)
360+
361+
# session_guard wrapper return a dict
362+
if isinstance(result, EaseVoiceResponse):
363+
return result
364+
logger.error(f"failed to uvr5: {result}")
365+
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=result)
366+
367+
async def audio_slicer(self, request: AudioSlicerParams):
368+
result = self._do_audio_slicer(request)
369+
370+
# session_guard wrapper return a dict
371+
if isinstance(result, EaseVoiceResponse):
372+
return result
373+
logger.error(f"failed to slice: {result}")
374+
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=result)
375+
376+
async def audio_denoise(self, request: AudioDenoiseParams):
377+
result = self._do_audio_denoise(request)
378+
379+
# session_guard wrapper return a dict
380+
if isinstance(result, EaseVoiceResponse):
381+
return result
382+
logger.error(f"failed to denoise: {result}")
383+
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=result)
384+
385+
async def audio_asr(self, request: AudioASRParams):
386+
result = self._do_audio_asr(request)
387+
388+
# session_guard wrapper return a dict
389+
if isinstance(result, EaseVoiceResponse):
390+
return result
391+
logger.error(f"failed to asr: {result}")
392+
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=result)
393+
394+
async def list_audio_refinement(self, input_dir: str, output_dir: str):
395+
service = AudioService(source_dir=input_dir, output_dir=output_dir)
396+
result = service.refinement_reload()
397+
if isinstance(result, EaseVoiceResponse):
398+
return result
399+
logger.error(f"failed to list audio refinement: {result}")
400+
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=result)
401+
402+
def update_audio_refinement(self, request: AudioRefinementSubmitParams):
403+
service = AudioService(source_dir=request.source_dir, output_dir=request.output_dir)
404+
result = service.refinement_submit_text(request.index, request.language, request.text_content)
405+
if isinstance(result, EaseVoiceResponse):
406+
return result
407+
logger.error(f"failed to update audio refinement: {result}")
408+
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=result)
409+
410+
def delete_audio_refinement(self, params: AudioRefinementDeleteParams):
411+
service = AudioService(source_dir=params.source_dir, output_dir=params.output_dir)
412+
result = service.refinement_delete_text(params.file_index)
413+
if isinstance(result, EaseVoiceResponse):
414+
return result
415+
logger.error(f"failed to delete audio refinement: {result}")
416+
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=result)
417+
418+
@session_guard("AudioUVR5")
419+
def _do_audio_uvr5(self, params: AudioUVR5Params):
420+
service = AudioService(source_dir=params.source_dir, output_dir=params.output_dir)
421+
return service.uvr5(params.model_name, params.audio_format)
422+
423+
@session_guard("AudioSlicer")
424+
def _do_audio_slicer(self, params: AudioSlicerParams):
425+
service = AudioService(source_dir=params.source_dir, output_dir=params.output_dir)
426+
return service.slicer(
427+
threshold=params.threshold,
428+
min_length=params.min_length,
429+
min_interval=params.min_interval,
430+
hop_size=params.hop_size,
431+
max_silent_kept=params.max_silent_kept,
432+
normalize_max=params.normalize_max,
433+
alpha_mix=params.alpha_mix,
434+
num_process=params.num_process,
435+
)
436+
437+
@session_guard("AudioDenoise")
438+
def _do_audio_denoise(self, params: AudioDenoiseParams):
439+
service = AudioService(source_dir=params.source_dir, output_dir=params.output_dir)
440+
return service.denoise()
441+
442+
@session_guard("AudioASR")
443+
def _do_audio_asr(self, params: AudioASRParams):
444+
service = AudioService(source_dir=params.source_dir, output_dir=params.output_dir)
445+
return service.asr(
446+
asr_model=params.asr_model,
447+
model_size=params.model_size,
448+
language=params.language,
449+
precision=params.precision,
450+
)
451+
452+
319453
# Initialize FastAPI and NamespaceService
320454
app = FastAPI()
321455

@@ -336,3 +470,9 @@ def _do_train_sovits(self, params: SovitsTrainParams):
336470

337471
train_api = TrainAPI()
338472
app.include_router(train_api.router, prefix="/apis/v1")
473+
474+
normalize_api = NormalizeAPI()
475+
app.include_router(normalize_api.router, prefix="/apis/v1")
476+
477+
audio_api = AudioAPI()
478+
app.include_router(audio_api.router, prefix="/apis/v1")

src/service/audio.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import traceback
66
from multiprocessing import Process
7+
from dataclasses import dataclass
78

89
import ffmpeg
910
import numpy as np
@@ -21,6 +22,60 @@
2122
from src.utils.response import ResponseStatus, EaseVoiceResponse
2223

2324

25+
@dataclass
class AudioUVR5Params:
    """Request payload for the UVR5 endpoint."""

    # Directories handed to AudioService(source_dir=..., output_dir=...).
    source_dir: str
    output_dir: str
    # Passed positionally to AudioService.uvr5(model_name, audio_format).
    model_name: str
    audio_format: str
31+
32+
33+
@dataclass
class AudioSlicerParams:
    """Request payload for the slicer endpoint.

    Only the two directories are required; every tuning field has a default
    and is forwarded by keyword to ``AudioService.slicer``.
    """

    # Directories handed to AudioService(source_dir=..., output_dir=...).
    source_dir: str
    output_dir: str
    # Slicer tuning values. NOTE(review): units are not stated here
    # (presumably dB for threshold and ms for the lengths) -- confirm.
    threshold: int = -34
    min_length: int = 4000
    min_interval: int = 300
    hop_size: int = 10
    max_silent_kept: int = 500
    normalize_max: float = 0.9
    alpha_mix: float = 0.25
    num_process: int = 4
45+
46+
47+
@dataclass
class AudioDenoiseParams:
    """Request payload for the denoise endpoint."""

    # Directories handed to AudioService(source_dir=..., output_dir=...).
    source_dir: str
    output_dir: str
51+
52+
53+
@dataclass
class AudioASRParams:
    """Request payload for the ASR endpoint.

    Only the two directories are required; the model settings have defaults
    and are forwarded by keyword to ``AudioService.asr``.
    """

    # Directories handed to AudioService(source_dir=..., output_dir=...).
    source_dir: str
    output_dir: str
    # Model selection and runtime settings.
    asr_model: str = "funasr"
    model_size: str = "large"
    language: str = "zh"
    precision: str = "float32"
61+
62+
63+
@dataclass
class AudioRefinementSubmitParams:
    """Request payload for submitting refined text for one audio entry."""

    # Directories handed to AudioService(source_dir=..., output_dir=...).
    source_dir: str
    output_dir: str
    # Passed in this order to
    # AudioService.refinement_submit_text(index, language, text_content).
    index: str
    language: str
    text_content: str
70+
71+
72+
@dataclass
class AudioRefinementDeleteParams:
    """Request payload for deleting one refinement entry."""

    # Directories handed to AudioService(source_dir=..., output_dir=...).
    source_dir: str
    output_dir: str
    # Passed to AudioService.refinement_delete_text(file_index).
    file_index: str
77+
78+
2479
class AudioService():
2580
def __init__(self, source_dir: str, output_dir: str):
2681
super().__init__()

src/service/normalize.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
#!/usr/bin/env python
22
# -*- encoding=utf8 -*-
33

4+
from dataclasses import dataclass
5+
46
from src.normalization.normalize import Normalize
57
from src.utils.response import EaseVoiceResponse, ResponseStatus
68

79

10+
@dataclass
class NormalizeParams:
    """Request payload for the normalize endpoint."""

    # Path handed to NormalizeService(processing_path).
    processing_path: str
13+
14+
815
class NormalizeService(object):
916
def __init__(self, processing_path: str):
1017
self.processing_path = processing_path

0 commit comments

Comments
 (0)