Skip to content

Commit 81cbdc2

Browse files
committed
spell_jam: refactor for switchable TTS backend
1 parent 413e195 commit 81cbdc2

File tree

4 files changed

+169
-43
lines changed

4 files changed

+169
-43
lines changed

Fruit_Jam/Fruit_Jam_Spell_Jam/aws_polly.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def text_to_speech_polly_http(
199199
text,
200200
access_key,
201201
secret_key,
202-
output_file="/saves/awspollyoutput.mp3",
202+
output_file="/saves/tts_output.mp3",
203203
voice_id="Joanna",
204204
region="us-east-1",
205205
output_format="mp3",

Fruit_Jam/Fruit_Jam_Spell_Jam/code.py

Lines changed: 19 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import os
44
import sys
55
import time
6-
76
import supervisor
7+
88
from adafruit_fruitjam import FruitJam
99
from adafruit_fruitjam.peripherals import request_display_config
1010
import adafruit_connection_manager
@@ -13,10 +13,16 @@
1313
from adafruit_bitmap_font import bitmap_font
1414
from adafruit_display_text.bitmap_label import Label
1515

16-
from aws_polly import text_to_speech_polly_http
17-
1816
from launcher_config import LauncherConfig
1917

18+
# If tts_local.py exists, use that instead of tts_aws.py
19+
try:
20+
# tts_local defines WordFetcherTTS for TTS engine running on local network server
21+
from tts_local import WordFetcherTTS
22+
except ImportError:
23+
from tts_aws import WordFetcherTTS
24+
25+
# read the user settings
2026
launcher_config = LauncherConfig()
2127

2228
# constants
@@ -60,55 +66,26 @@
6066

6167
fj.neopixels.brightness = 0.1
6268

63-
# AWS auth requires us to have accurate date/time
64-
now = fj.sync_time()
65-
66-
# setup adafruit_requests session
67-
# pylint: disable=protected-access
68-
pool = adafruit_connection_manager.get_radio_socketpool(fj.network._wifi.esp)
69-
ssl_context = adafruit_connection_manager.get_radio_ssl_context(fj.network._wifi.esp)
70-
requests = adafruit_requests.Session(pool, ssl_context)
71-
72-
# read AWS keys from settings.toml
73-
AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY")
74-
AWS_SECRET_KEY = os.getenv("AWS_SECRET_KEY")
75-
76-
77-
def fetch_word(word, voice="Joanna"):
78-
"""
79-
Fetch an MP3 saying a word from AWS Polly
80-
:param word: The word to speak
81-
:param voice: The AWS Polly voide ID to use
82-
:return: Boolean, whether the request was successful.
83-
"""
84-
85-
if AWS_ACCESS_KEY is None or AWS_SECRET_KEY is None:
86-
return False
87-
88-
fj.neopixels.fill(0xFFFF00)
89-
success = text_to_speech_polly_http(
90-
requests,
91-
text=word,
92-
access_key=AWS_ACCESS_KEY,
93-
secret_key=AWS_SECRET_KEY,
94-
voice_id=voice,
95-
)
96-
fj.neopixels.fill(0x00FF00)
97-
return success
98-
69+
word_fetcher = WordFetcherTTS(fj, launcher_config)
9970

10071
def _play_fetched_word():
    """
    Play the audio file most recently written by ``word_fetcher``.

    The player is chosen from the extension of ``word_fetcher.output_path``:
    ``.mp3`` files go through ``fj.play_mp3_file`` and ``.wav`` files through
    ``fj.play_file``. Any other extension is silently skipped.
    """
    path = word_fetcher.output_path
    if path.endswith(".mp3"):
        fj.play_mp3_file(path)
    elif path.endswith(".wav"):
        fj.play_file(path)


def say_and_spell_lastword():
    """
    Say the last word, then spell it out one letter at a time, finally say it once more.

    The spoken word is only played when ``sayword`` is truthy (i.e. the last
    fetch succeeded); the letter-by-letter spelling always plays from the
    bundled letter MP3s. The NeoPixels are turned off when playback is done.
    """
    if sayword:
        _play_fetched_word()
    time.sleep(0.2)
    for letter in lastword:
        fj.play_mp3_file(f"spell_jam_assets/letter_mp3s/{letter.upper()}.mp3")
    time.sleep(0.2)
    if sayword:
        _play_fetched_word()
    fj.neopixels.fill(0x000000)
11390

11491

@@ -133,7 +110,7 @@ def say_and_spell_lastword():
133110
elif c == "\n":
134111
if curword:
135112
lastword = curword
136-
sayword = fetch_word(lastword)
113+
sayword = word_fetcher.fetch_word(lastword)
137114
say_and_spell_lastword()
138115
curword = ""
139116
else:
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# tts_aws.py
2+
import os
3+
import adafruit_connection_manager
4+
import adafruit_requests
5+
from aws_polly import text_to_speech_polly_http
6+
7+
class WordFetcherTTS():
    """
    Fetch spoken-word audio for Spell Jam from AWS Polly.

    :param fj: FruitJam object used for time sync, network access and the
        NeoPixel status colours. When ``None``, no requests session is
        created and :meth:`fetch_word` always returns ``False``.
    :param launcher_config: launcher settings object; stored but not read
        by this backend.
    :param output_path: file the MP3 returned by Polly is written to.
    """

    def __init__(self, fj=None, launcher_config=None, output_path="/saves/tts_output.mp3"):
        self.output_path = output_path
        self.fj = fj
        self.launcher_config = launcher_config
        self.requests = None

        # read AWS keys from settings.toml
        self.AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY")
        self.AWS_SECRET_KEY = os.getenv("AWS_SECRET_KEY")

        # fj is optional in the signature; without it there is no radio to
        # build a session on, so leave self.requests as None.
        if fj is None:
            return

        # AWS auth requires us to have accurate date/time
        fj.sync_time()

        # setup adafruit_requests session
        # pylint: disable=protected-access
        esp = fj.network._wifi.esp
        pool = adafruit_connection_manager.get_radio_socketpool(esp)
        ssl_context = adafruit_connection_manager.get_radio_ssl_context(esp)
        self.requests = adafruit_requests.Session(pool, ssl_context)

    def fetch_word(self, word: str, voice: str = "Joanna") -> bool:
        """
        Fetch an MP3 saying a word from AWS Polly and save it to ``output_path``.

        :param word: The word to speak
        :param voice: The AWS Polly voice ID to use
        :return: Boolean, whether the request was successful.
        """
        if not self.AWS_ACCESS_KEY or not self.AWS_SECRET_KEY:
            print("Missing AWS credentials.")
            return False

        if self.requests is None:
            # constructed without a FruitJam object, so there is no session
            print("No network session available.")
            return False

        # yellow while the request is in flight
        if self.fj:
            self.fj.neopixels.fill(0xFFFF00)

        success = text_to_speech_polly_http(
            self.requests,
            text=word,
            access_key=self.AWS_ACCESS_KEY,
            secret_key=self.AWS_SECRET_KEY,
            output_file=self.output_path,
            voice_id=voice,
            region="us-east-1",
            output_format="mp3",
        )

        # green once the request has returned (regardless of outcome)
        if self.fj:
            self.fj.neopixels.fill(0x00FF00)
        return success
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# tts_kani.py -- kani-tts backend; code.py only loads it when it is importable as tts_local
2+
import json
3+
import adafruit_connection_manager
4+
import adafruit_requests
5+
6+
class WordFetcherTTS():
    """
    Fetch spoken-word audio for Spell Jam from a kani-tts server over HTTP.

    Args:
        fj: FruitJam object used for time sync, network access and the
            NeoPixel status colours. When None, no requests session is
            created and every fetch reports failure.
        launcher_config: object whose ``data`` dict may contain
            ``["spell_jam"]["tts_server_endpoint"]``, the server base URL.
        output_path (str): file the received audio bytes are written to.
    """

    def __init__(self, fj=None, launcher_config=None, output_path="/saves/tts_output.wav"):
        self.output_path = output_path
        self.launcher_config = launcher_config
        self.fj = fj
        self.requests = None

        # fj is optional in the signature; without it there is no radio to
        # build a session on, so leave self.requests as None.
        if fj is None:
            return

        # NOTE(review): call kept from the AWS backend, where request signing
        # needs accurate time; this backend does not use the result --
        # confirm whether it is still needed (e.g. to bring the network up).
        fj.sync_time()

        # setup adafruit_requests session (no SSL context is supplied)
        # pylint: disable=protected-access
        pool = adafruit_connection_manager.get_radio_socketpool(fj.network._wifi.esp)
        self.requests = adafruit_requests.Session(pool)

    def fetch_word(self, word: str, voice: str = "katie") -> bool:
        """
        Synthesize ``word`` on the TTS server and save the audio to ``output_path``.

        Args:
            word (str): The word to speak
            voice (str): voice ID sent to the server

        Returns:
            bool: True if audio was received and written, False otherwise
        """
        # yellow while the request is in flight
        if self.fj:
            self.fj.neopixels.fill(0xFFFF00)

        audio_data = self.text_to_speech_http(
            text=word,
            voice_id=voice,
        )

        success = False
        if audio_data:
            # Save to file
            try:
                with open(self.output_path, "wb") as f:
                    f.write(audio_data)
                print(f"Audio saved to: {self.output_path}")
                success = True
            except Exception as e:  # pylint: disable=broad-except
                print(f"Failed to save file: {e}")
        else:
            print("Failed to synthesize speech")

        # green once the request has returned (regardless of outcome)
        if self.fj:
            self.fj.neopixels.fill(0x00FF00)
        return success

    def text_to_speech_http(
        self,
        text,
        voice_id,
    ):
        """
        Simple function to convert text to speech using kani-tts AI local server.py HTTP API

        Args:
            text (str): Text to convert
            voice_id (str): voice ID

        Returns:
            The response body (audio data) on HTTP 200, or None when the
            endpoint is not configured, no session exists, the server
            answers with another status, or the request raises.
        """

        # Look up the server endpoint; only touch launcher_config.data once we
        # know launcher_config is set.
        endpoint = ""
        if self.launcher_config:
            print(self.launcher_config.data)
            if "spell_jam" in self.launcher_config.data:
                endpoint = self.launcher_config.data["spell_jam"].get("tts_server_endpoint", "")
        if endpoint == "":
            print("tts_server_endpoint not configured in launcher.conf.json.")
            return None

        if self.requests is None:
            # constructed without a FruitJam object, so there is no session
            print("No network session available.")
            return None

        uri = "/tts"

        # Create request body; the voice is passed as a "<voice>: " prefix on the text
        request_body = {
            "text": f'{voice_id}: {text}',
            "temperature": 0.4,
            "max_tokens": 400,
            "top_p": 0.95,
            "chunk_size": 25,
            "lookback_frames": 15,
        }
        payload = json.dumps(request_body)
        url = f"{endpoint}{uri}"
        headers = {"Content-Type": "application/json"}
        print(f"Making request to: {url}, headers: {headers}, payload: {payload}")

        try:
            response = self.requests.post(url, headers=headers, data=payload)

            if response.status_code == 200:
                return response.content

            print(f"Error: HTTP {response.status_code}")
            print(f"Response: {response.text}")
            return None

        except Exception as e:  # pylint: disable=broad-except
            print(f"Request failed: {e}")
            return None
104+

0 commit comments

Comments
 (0)