
Commit d33f22d

add sesame-csm-1b example
1 parent fc07ffb commit d33f22d

File tree

5 files changed: +283 −22 lines changed

sesame-csm-1b/config.yaml
sesame-csm-1b/model/__init__.py
sesame-csm-1b/model/model.py
xtts-streaming/model/model.py
xtts-streaming/test.py


sesame-csm-1b/config.yaml

Lines changed: 28 additions & 0 deletions (new file)

model_name: sesame-csm-1b
python_version: py310
model_metadata:
  example_model_input:
    text: "Hello from Sesame."
    speaker: 0
requirements:
  - torch==2.4.0
  - torchaudio==2.4.0
  - tokenizers==0.21.0
  - transformers==4.49.0
  - huggingface_hub==0.28.1
  - moshi==0.2.2
  - torchtune==0.4.0
  - torchao==0.9.0
  - silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master
  - ffmpeg
  - git+https://github.com/veerbia/csm.git
resources:
  accelerator: T4
  cpu: '1'
  memory: 10Gi
  use_gpu: true
secrets:
  hf_access_token: null
system_packages: []
environment_variables: {}
external_package_dirs: []
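
Once deployed, the example_model_input above is exactly the JSON body a predict call accepts. A minimal client sketch, assuming a hypothetical Baseten model ID and API key (both are placeholders, not values from this commit):

import requests

# Placeholder model ID and API key; substitute your own deployment's values.
resp = requests.post(
    "https://model-XXXXXXX.api.baseten.co/production/predict",
    headers={"Authorization": "Api-Key YOUR_API_KEY"},
    json={"text": "Hello from Sesame.", "speaker": 0},
)
resp.raise_for_status()
result = resp.json()  # {"output": "<base64-encoded WAV>"} per model/model.py below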

sesame-csm-1b/model/__init__.py

Whitespace-only changes.

sesame-csm-1b/model/model.py

Lines changed: 32 additions & 0 deletions (new file)

import base64
from io import BytesIO
from huggingface_hub import hf_hub_download
from generator import load_csm_1b
import torchaudio
import torch


class Model:
    def __init__(self, **kwargs):
        self.generator = None
        self._secrets = kwargs["secrets"]

    def load(self):
        model_path = hf_hub_download(
            repo_id="sesame/csm-1b",
            filename="ckpt.pt",
            token=self._secrets["hf_access_token"],
        )
        self.generator = load_csm_1b(model_path, "cuda", self._secrets["hf_access_token"])

    def wav_to_base64(self, wav_tensor):
        buffer = BytesIO()
        torchaudio.save(
            buffer, wav_tensor.unsqueeze(0).cpu(), self.generator.sample_rate, format="wav"
        )
        buffer.seek(0)
        return base64.b64encode(buffer.read()).decode("utf-8")

    def predict(self, model_input):
        text = model_input.get("text", "Hello from Sesame.")
        speaker = model_input.get("speaker", 0)
        audio = self.generator.generate(
            text=text,
            speaker=speaker,
            context=[],
            max_audio_length_ms=10_000,
        )
        return {"output": self.wav_to_base64(audio)}

xtts-streaming/model/model.py

Lines changed: 55 additions & 22 deletions
@@ -4,13 +4,15 @@
 import os
 import time
 import wave
+import json

 import numpy as np
 import torch
 from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.xtts import Xtts
 from TTS.utils.generic_utils import get_user_data_dir
 from TTS.utils.manage import ModelManager
+import fastapi

 # This is one of the speaker voices that comes with xtts
 SPEAKER_NAME = "Claribel Dervla"

@@ -33,12 +35,10 @@ def load(self):
         config = XttsConfig()
         config.load_json(os.path.join(model_path, "config.json"))
         self.model = Xtts.init_from_config(config)
-        # self.model.load_checkpoint(config, checkpoint_dir=model_path, eval=True)
         self.model.load_checkpoint(
             config, checkpoint_dir=model_path, eval=True, use_deepspeed=True
         )
         self.model.to(device)
-        # self.compiled_model = torch.compile(self.model.inference_stream)

         self.speaker = {
             "speaker_embedding": self.model.speaker_manager.speakers[SPEAKER_NAME][

@@ -78,25 +78,58 @@ def wav_postprocess(self, wav):
         wav = (wav * 32767).astype(np.int16)
         return wav

-    def predict(self, model_input):
-        text = model_input.get("text")
-        language = model_input.get("language", "en")
-        chunk_size = int(
-            model_input.get("chunk_size", 20)
-        )  # Ensure chunk_size is an integer
-        add_wav_header = False

-        streamer = self.model.inference_stream(
-            text,
-            language,
-            self.gpt_cond_latent,
-            self.speaker_embedding,
-            stream_chunk_size=chunk_size,
-            enable_text_splitting=True,
-            temperature=0.2,
-        )
+    async def websocket(self, websocket: fastapi.WebSocket):
+        """Handle WebSocket connections for text-to-speech requests"""
+        print("WebSocket connected")
+        try:
+            while True:
+                data = await websocket.receive_text()
+
+                try:
+                    # Parse JSON input if provided
+                    input_data = json.loads(data)
+                except json.JSONDecodeError:
+                    # If not JSON, assume it's just text
+                    input_data = {"text": data, "language": "en", "chunk_size": 20}
+
+                text = input_data.get("text")
+                language = input_data.get("language", "en")
+                chunk_size = int(input_data.get("chunk_size", 20))
+
+                # Process the text to speech using the logic from the original predict method
+                streamer = self.model.inference_stream(
+                    text,
+                    language,
+                    self.gpt_cond_latent,
+                    self.speaker_embedding,
+                    stream_chunk_size=chunk_size,
+                    enable_text_splitting=True,
+                    temperature=0.2,
+                )

-        for chunk in streamer:
-            processed_chunk = self.wav_postprocess(chunk)
-            processed_bytes = processed_chunk.tobytes()
-            yield processed_bytes
+                for chunk in streamer:
+                    processed_chunk = self.wav_postprocess(chunk)
+                    processed_bytes = processed_chunk.tobytes()
+                    encoded_chunk = base64.b64encode(processed_bytes).decode('utf-8')
+                    await websocket.send_json({
+                        "type": "chunk",
+                        "data": encoded_chunk
+                    })
+
+                await websocket.send_json({
+                    "type": "complete",
+                    "message": f"Processed '{text}'"
+                })
+
+        except fastapi.WebSocketDisconnect:
+            print("WebSocket disconnected")
+        except Exception as e:
+            print(f"WebSocket error: {str(e)}")
+            try:
+                await websocket.send_json({
+                    "type": "error",
+                    "message": str(e)
+                })
+            except:
+                pass
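
Each "chunk" message carries base64-encoded raw PCM produced by wav_postprocess: 16-bit samples, and, judging by the WAV parameters used in test.py below, mono at XTTS's 24 kHz output rate. Under those assumptions, playback duration follows directly from the byte count, a small sketch:

# Assumes audio_chunks holds the decoded chunk payloads, as collected by
# a client like xtts-streaming/test.py below.
pcm = b"".join(audio_chunks)
sample_rate, bytes_per_sample = 24000, 2  # 24 kHz, 16-bit mono (assumed)
duration_s = len(pcm) / (sample_rate * bytes_per_sample)
print(f"Received {duration_s:.2f}s of audio")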

xtts-streaming/test.py

Lines changed: 168 additions & 0 deletions (new file)

import asyncio
import websockets
import json
import base64
import wave
import os


def wav_to_base64(wav_path):
    """Convert a WAV file to a base64-encoded string"""
    with open(wav_path, "rb") as wav_file:
        return base64.b64encode(wav_file.read()).decode('utf-8')


async def send_websocket_data():
    # Connection details
    uri = "wss://model-rwn1jgd3.api.baseten.co/v1/websocket"
    headers = {"Authorization": "Api-Key vVolDAU0.Mbynm8M7VGnaGqLbW9pwfWxFePNrGw8G"}

    async with websockets.connect(uri, extra_headers=headers) as websocket:
        # For the TTS model, we send text instead of audio
        text_data = {
            "text": "Hello, this is a test of the text to speech websocket API.",
            "language": "en",
            "chunk_size": 20
        }

        # Send the text data as JSON
        await websocket.send(json.dumps(text_data))
        print(f"Sent text: {text_data['text']}")

        # Collect audio chunks
        audio_chunks = []

        # Process responses
        while True:
            try:
                response = await websocket.recv()

                # Try to parse as JSON
                try:
                    data = json.loads(response)
                    print(f"Received response: {data.get('type', 'unknown')}")

                    if data.get("type") == "chunk":
                        # Decode and save the audio chunk
                        audio_chunk = base64.b64decode(data["data"])
                        audio_chunks.append(audio_chunk)
                        print("Saved audio chunk")

                    elif data.get("type") == "complete":
                        print(f"Processing complete: {data.get('message')}")
                        break

                    elif data.get("type") == "error":
                        print(f"Error: {data.get('message')}")
                        break

                except json.JSONDecodeError:
                    # Not JSON, print the first part
                    print(f"Received non-JSON response: {response[:50]}...")
                    break

            except Exception as e:
                print(f"Error receiving data: {str(e)}")
                break

        # Save the audio to a WAV file if we received chunks
        if audio_chunks:
            output_file = "tts_output.wav"
            with wave.open(output_file, 'wb') as wf:
                wf.setnchannels(1)  # Mono
                wf.setsampwidth(2)  # 16-bit
                wf.setframerate(24000)  # XTTS default sample rate
                wf.writeframes(b''.join(audio_chunks))

            print(f"Audio saved to {output_file}")
            print(f"Full path: {os.path.abspath(output_file)}")
        else:
            print("No audio data received")


async def test_multiple_concurrent_requests():
    """Test sending multiple concurrent requests to the TTS websocket API"""

    async def single_request(idx):
        """Handle a single request with unique text and output file"""
        output_file = f"tts_output_{idx}.wav"
        text = f"This is concurrent test number {idx}."

        try:
            # Connection details
            uri = "wss://model-rwn1jgd3.api.baseten.co/v1/websocket"
            headers = {"Authorization": "Api-Key vVolDAU0.Mbynm8M7VGnaGqLbW9pwfWxFePNrGw8G"}

            async with websockets.connect(uri, extra_headers=headers) as websocket:
                # Send text data as JSON
                text_data = {
                    "text": text,
                    "language": "en",
                    "chunk_size": 20
                }

                await websocket.send(json.dumps(text_data))
                print(f"Request {idx}: Sent text: {text}")

                # Collect audio chunks
                audio_chunks = []

                # Process responses
                while True:
                    try:
                        response = await websocket.recv()

                        # Try to parse as JSON
                        try:
                            data = json.loads(response)

                            if data.get("type") == "chunk":
                                # Decode and save the audio chunk
                                audio_chunk = base64.b64decode(data["data"])
                                audio_chunks.append(audio_chunk)

                            elif data.get("type") == "complete":
                                print(f"Request {idx}: Processing complete")
                                break

                            elif data.get("type") == "error":
                                print(f"Request {idx}: Error: {data.get('message')}")
                                return False

                        except json.JSONDecodeError:
                            print(f"Request {idx}: Received non-JSON response")
                            return False

                    except Exception as e:
                        print(f"Request {idx}: Error receiving data: {str(e)}")
                        return False

                # Save the audio to a WAV file if we received chunks
                if audio_chunks:
                    with wave.open(output_file, 'wb') as wf:
                        wf.setnchannels(1)  # Mono
                        wf.setsampwidth(2)  # 16-bit
                        wf.setframerate(24000)  # XTTS default sample rate
                        wf.writeframes(b''.join(audio_chunks))

                    print(f"Request {idx}: Audio saved to {output_file}")
                    return True
                else:
                    print(f"Request {idx}: No audio data received")
                    return False

        except Exception as e:
            print(f"Request {idx}: Failed with exception: {str(e)}")
            return False

    num_requests = 4

    print(f"Starting {num_requests} concurrent requests...")
    results = await asyncio.gather(*[single_request(i + 1) for i in range(num_requests)])

    successful = results.count(True)
    print(f"Completed {successful} out of {num_requests} requests successfully")
    return successful == num_requests


# Run the tests
if __name__ == "__main__":
    asyncio.run(send_websocket_data())
    print("\n--- Testing multiple concurrent requests ---\n")
    asyncio.run(test_multiple_concurrent_requests())
