1010import torch
1111import torch .nn .functional as F
1212import torchaudio
13- from demucs import pretrained
14- from demucs .apply import apply_model
1513from fastapi import HTTPException , UploadFile
1614from model import DialectClassifier
1715from starlette .middleware .cors import CORSMiddleware
@@ -31,7 +29,6 @@ def setup(self, device: str):
3129 map_location = device ,
3230 )
3331 self .model .eval ()
34- self .demucs_model = pretrained .get_model ("htdemucs" )
3532 self .transform = A .Compose (
3633 [
3734 A .AddGaussianNoise (p = 0.2 ),
@@ -54,13 +51,8 @@ def setup(self, device: str):
5451 def separate_audio (self , audio_path : str ) -> str :
5552 wav , sr = torchaudio .load (audio_path )
5653 wav = wav .mean (dim = 0 , keepdim = True )
57-
58- with torch .no_grad ():
59- sources = apply_model (self .demucs_model , wav .unsqueeze (0 ), shifts = 1 )
60-
61- vocals = sources [0 , 3 ]
6254 separated_path = f"/tmp/separated_{ time .time ()} .wav"
63- torchaudio .save (separated_path , vocals , sr )
55+ torchaudio .save (uri = separated_path , src = wav , sample_rate = sr )
6456 return separated_path
6557
6658 def remove_silence (self , audio : np .ndarray , sr : int ) -> np .ndarray :
0 commit comments