4
4
import os
5
5
import time
6
6
import wave
7
- import json
8
7
9
8
import numpy as np
10
9
import torch
11
10
from TTS .tts .configs .xtts_config import XttsConfig
12
11
from TTS .tts .models .xtts import Xtts
13
12
from TTS .utils .generic_utils import get_user_data_dir
14
13
from TTS .utils .manage import ModelManager
15
- import fastapi
16
14
17
15
# This is one of the speaker voices that come with XTTS
18
16
SPEAKER_NAME = "Claribel Dervla"
@@ -35,10 +33,12 @@ def load(self):
35
33
config = XttsConfig ()
36
34
config .load_json (os .path .join (model_path , "config.json" ))
37
35
self .model = Xtts .init_from_config (config )
36
+ # self.model.load_checkpoint(config, checkpoint_dir=model_path, eval=True)
38
37
self .model .load_checkpoint (
39
38
config , checkpoint_dir = model_path , eval = True , use_deepspeed = True
40
39
)
41
40
self .model .to (device )
41
+ # self.compiled_model = torch.compile(self.model.inference_stream)
42
42
43
43
self .speaker = {
44
44
"speaker_embedding" : self .model .speaker_manager .speakers [SPEAKER_NAME ][
@@ -78,58 +78,25 @@ def wav_postprocess(self, wav):
78
78
wav = (wav * 32767 ).astype (np .int16 )
79
79
return wav
80
80
81
def predict(self, model_input):
    """Stream synthesized speech for one request as raw byte chunks.

    This is a generator: nothing runs (including input validation) until the
    caller starts iterating.

    Args:
        model_input: Mapping with the request fields:
            - "text": the text to synthesize (required).
            - "language": language code passed to the model, default "en".
            - "chunk_size": streaming chunk size, default 20; coerced with
              ``int`` so numeric strings are accepted.

    Yields:
        bytes: ``self.wav_postprocess(chunk).tobytes()`` for every chunk
        produced by ``self.model.inference_stream``. Only these bytes are
        yielded; nothing (such as a WAV header) is prepended.

    Raises:
        ValueError: if "text" is missing from ``model_input`` (or is None),
            or if "chunk_size" cannot be parsed by ``int``.
    """
    text = model_input.get("text")
    if text is None:
        # Fail with an explicit message instead of handing None to the model.
        raise ValueError('model_input must contain a "text" field')

    language = model_input.get("language", "en")
    # Request payloads may carry chunk_size as a string; normalize to int.
    chunk_size = int(model_input.get("chunk_size", 20))

    streamer = self.model.inference_stream(
        text,
        language,
        self.gpt_cond_latent,
        self.speaker_embedding,
        stream_chunk_size=chunk_size,
        enable_text_splitting=True,
        temperature=0.2,
    )

    for chunk in streamer:
        yield self.wav_postprocess(chunk).tobytes()
0 commit comments