Skip to content

Commit 5f165b0

Browse files
committed
update interface for tts
1 parent eff8501 commit 5f165b0

File tree

3 files changed

+72
-73
lines changed

3 files changed

+72
-73
lines changed

template_langgraph/services/streamlits/pages/chat_with_tools_agent.py

Lines changed: 16 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
import os
22
import tempfile
33
from base64 import b64encode
4-
from datetime import datetime
54

65
import streamlit as st
76
from audio_recorder_streamlit import audio_recorder
@@ -14,7 +13,7 @@
1413
ChatWithToolsAgent,
1514
)
1615
from template_langgraph.speeches.stt import SttWrapper
17-
from template_langgraph.speeches.tts import synthesize_audio
16+
from template_langgraph.speeches.tts import TtsWrapper
1817
from template_langgraph.tools.common import get_default_tools
1918

2019

@@ -169,35 +168,25 @@ def load_stt_wrapper(model_size: str = "base"):
169168
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file:
170169
temp_audio_file.write(audio_bytes)
171170
temp_audio_file_path = temp_audio_file.name
172-
st.download_button(
173-
label="🎧 録音データを保存",
174-
data=audio_bytes,
175-
file_name=f"recorded_{datetime.now().strftime('%Y%m%d_%H%M%S')}.wav",
176-
mime="audio/wav",
177-
use_container_width=True,
178-
)
179171
try:
180-
if input_output_mode == "音声":
181-
with st.spinner("音声を認識中..."):
182-
stt_wrapper = load_stt_wrapper(selected_model)
183-
language_param = None if transcription_language == "auto" else transcription_language
184-
result = stt_wrapper.transcribe(str(temp_audio_file_path), language=language_param)
185-
transcribed_text = result.get("text", "").strip()
186-
prompt_text = transcribed_text
187-
188-
if prompt_text:
189-
st.success(f"音声認識完了: {prompt_text}")
190-
prompt = prompt_text
191-
else:
192-
st.warning("音声が認識できませんでした")
172+
with st.spinner("音声を認識中..."):
173+
stt_wrapper = load_stt_wrapper(selected_model)
174+
language_param = None if transcription_language == "auto" else transcription_language
175+
transcribed_text = stt_wrapper.transcribe(str(temp_audio_file_path), language=language_param)
176+
prompt_text = transcribed_text
177+
178+
if prompt_text:
179+
st.success(f"音声認識結果: {prompt_text}")
180+
prompt = prompt_text
181+
else:
182+
st.warning("音声が認識できませんでした")
193183
except Exception as e:
194184
st.error(f"音声認識でエラーが発生しました: {e}")
195185
prompt_text = "音声入力でエラーが発生しました"
196186
finally:
197187
if os.path.exists(temp_audio_file_path):
198188
os.unlink(temp_audio_file_path)
199-
200-
else:
189+
elif input_output_mode == "テキスト":
201190
# 既存のテキスト入力モード
202191
if prompt := st.chat_input(
203192
accept_file="multiple",
@@ -210,6 +199,8 @@ def load_stt_wrapper(model_size: str = "base"):
210199
],
211200
):
212201
pass # promptは既に設定済み
202+
else:
203+
st.error("不明な入出力モードです")
213204

214205
# 共通の入力処理ロジック
215206
if prompt:
@@ -290,7 +281,7 @@ def load_stt_wrapper(model_size: str = "base"):
290281
if input_output_mode == "音声":
291282
try:
292283
with st.spinner("音声を生成中です..."):
293-
audio_bytes = synthesize_audio(
284+
audio_bytes = TtsWrapper().synthesize_audio(
294285
text=response_content,
295286
language=tts_language,
296287
speed=tts_speed,

template_langgraph/speeches/stt.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,4 +28,4 @@ def transcribe(
2828
audio=audio_path,
2929
language=language,
3030
)
31-
return result["text"]
31+
return result.get("text", "").strip()

template_langgraph/speeches/tts.py

Lines changed: 55 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -13,51 +13,59 @@
1313
)
1414

1515

16-
def synthesize_audio(
17-
text: str,
18-
language: str = "ja",
19-
speed: float = 1.0,
20-
pitch_shift: int = 0,
21-
volume_db: float = 0.0,
22-
) -> bytes | None:
23-
"""Convert text to speech audio using gTTS and pydub adjustments."""
24-
25-
if not text.strip():
26-
return None
27-
28-
try:
29-
tts = gTTS(text=text, lang=language)
30-
mp3_buffer = io.BytesIO()
31-
tts.write_to_fp(mp3_buffer)
32-
mp3_buffer.seek(0)
33-
34-
audio_segment = AudioSegment.from_file(mp3_buffer, format="mp3")
35-
original_rate = audio_segment.frame_rate
36-
37-
if pitch_shift != 0:
38-
semitone_ratio = 2.0 ** (pitch_shift / 12.0)
39-
shifted = audio_segment._spawn(
40-
audio_segment.raw_data,
41-
overrides={"frame_rate": int(original_rate * semitone_ratio)},
42-
)
43-
audio_segment = shifted.set_frame_rate(original_rate)
44-
45-
if speed != 1.0:
46-
if speed > 1.0:
47-
audio_segment = speedup(audio_segment, playback_speed=float(speed))
48-
else:
49-
slowed_rate = max(int(original_rate * float(speed)), 1)
50-
audio_segment = audio_segment._spawn(
16+
class TtsWrapper:
17+
def __init__(self):
18+
pass
19+
20+
def load_model(self):
21+
pass
22+
23+
def synthesize_audio(
24+
self,
25+
text: str,
26+
language: str = "ja",
27+
speed: float = 1.0,
28+
pitch_shift: int = 0,
29+
volume_db: float = 0.0,
30+
) -> bytes | None:
31+
"""Convert text to speech audio using gTTS and pydub adjustments."""
32+
33+
if not text.strip():
34+
return None
35+
36+
try:
37+
tts = gTTS(text=text, lang=language)
38+
mp3_buffer = io.BytesIO()
39+
tts.write_to_fp(mp3_buffer)
40+
mp3_buffer.seek(0)
41+
42+
audio_segment = AudioSegment.from_file(mp3_buffer, format="mp3")
43+
original_rate = audio_segment.frame_rate
44+
45+
if pitch_shift != 0:
46+
semitone_ratio = 2.0 ** (pitch_shift / 12.0)
47+
shifted = audio_segment._spawn(
5148
audio_segment.raw_data,
52-
overrides={"frame_rate": slowed_rate},
53-
).set_frame_rate(original_rate)
54-
55-
if volume_db != 0:
56-
audio_segment += float(volume_db)
57-
58-
output_buffer = io.BytesIO()
59-
audio_segment.export(output_buffer, format="mp3")
60-
return output_buffer.getvalue()
61-
except Exception as e: # pragma: no cover
62-
logger.error(f"Error in synthesize_audio: {e}")
63-
return None
49+
overrides={"frame_rate": int(original_rate * semitone_ratio)},
50+
)
51+
audio_segment = shifted.set_frame_rate(original_rate)
52+
53+
if speed != 1.0:
54+
if speed > 1.0:
55+
audio_segment = speedup(audio_segment, playback_speed=float(speed))
56+
else:
57+
slowed_rate = max(int(original_rate * float(speed)), 1)
58+
audio_segment = audio_segment._spawn(
59+
audio_segment.raw_data,
60+
overrides={"frame_rate": slowed_rate},
61+
).set_frame_rate(original_rate)
62+
63+
if volume_db != 0:
64+
audio_segment += float(volume_db)
65+
66+
output_buffer = io.BytesIO()
67+
audio_segment.export(output_buffer, format="mp3")
68+
return output_buffer.getvalue()
69+
except Exception as e: # pragma: no cover
70+
logger.error(f"Error in synthesize_audio: {e}")
71+
return None

0 commit comments

Comments
 (0)