Skip to content

Commit b69f4ad

Browse files
committed
rm xtts-streaming websocket changes
1 parent d33f22d commit b69f4ad

File tree

2 files changed

+22
-223
lines changed

2 files changed

+22
-223
lines changed

xtts-streaming/model/model.py

Lines changed: 22 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,13 @@
44
import os
55
import time
66
import wave
7-
import json
87

98
import numpy as np
109
import torch
1110
from TTS.tts.configs.xtts_config import XttsConfig
1211
from TTS.tts.models.xtts import Xtts
1312
from TTS.utils.generic_utils import get_user_data_dir
1413
from TTS.utils.manage import ModelManager
15-
import fastapi
1614

1715
# This is one of the speaker voices that comes with xtts
1816
SPEAKER_NAME = "Claribel Dervla"
@@ -35,10 +33,12 @@ def load(self):
3533
config = XttsConfig()
3634
config.load_json(os.path.join(model_path, "config.json"))
3735
self.model = Xtts.init_from_config(config)
36+
# self.model.load_checkpoint(config, checkpoint_dir=model_path, eval=True)
3837
self.model.load_checkpoint(
3938
config, checkpoint_dir=model_path, eval=True, use_deepspeed=True
4039
)
4140
self.model.to(device)
41+
# self.compiled_model = torch.compile(self.model.inference_stream)
4242

4343
self.speaker = {
4444
"speaker_embedding": self.model.speaker_manager.speakers[SPEAKER_NAME][
@@ -78,58 +78,25 @@ def wav_postprocess(self, wav):
7878
wav = (wav * 32767).astype(np.int16)
7979
return wav
8080

def predict(self, model_input):
    """Stream synthesized speech for the given request.

    Args:
        model_input: dict describing the TTS request:
            - "text": the text to synthesize (required).
            - "language": language code; defaults to "en".
            - "chunk_size": streaming chunk size; defaults to 20.

    Yields:
        bytes: raw 16-bit PCM audio, one chunk at a time, as produced by
        the model's streaming inference and post-processed by
        ``self.wav_postprocess``.
    """
    text = model_input.get("text")
    language = model_input.get("language", "en")
    # Ensure chunk_size is an integer even when supplied as a string/float.
    chunk_size = int(model_input.get("chunk_size", 20))
    # NOTE(review): removed unused leftover `add_wav_header = False`
    # (dead local from the deleted WAV-header/websocket code path).

    streamer = self.model.inference_stream(
        text,
        language,
        self.gpt_cond_latent,
        self.speaker_embedding,
        stream_chunk_size=chunk_size,
        enable_text_splitting=True,
        temperature=0.2,
    )

    for chunk in streamer:
        # Convert the model's float waveform chunk to int16 PCM bytes.
        processed_chunk = self.wav_postprocess(chunk)
        yield processed_chunk.tobytes()

xtts-streaming/test.py

Lines changed: 0 additions & 168 deletions
This file was deleted.

0 commit comments

Comments
 (0)