
Commit f69d583

Commit message: audio 出力設定を追加 (Add audio output settings)
1 parent 2e25ff2

File tree: 3 files changed (+126, -39 lines)

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -34,6 +34,7 @@ dependencies = [
     "opentelemetry-sdk>=1.36.0",
     "psycopg2-binary>=2.9.10",
     "pydantic-settings>=2.9.1",
+    "pydub>=0.25.1",
     "pypdf>=5.9.0",
     "python-dotenv>=1.1.0",
     "qdrant-client>=1.15.1",

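For reference, and not part of this commit: a minimal sketch of the round trip the new pydub dependency enables. gTTS renders MP3 bytes into a buffer, and pydub decodes them into an AudioSegment for post-processing. Note that pydub shells out to an ffmpeg (or avconv) binary for MP3 decoding and encoding, so that tool must be available on the host.

import io

from gtts import gTTS
from pydub import AudioSegment

# Render a short phrase to an in-memory MP3 buffer (gTTS needs network access).
buffer = io.BytesIO()
gTTS(text="こんにちは", lang="ja").write_to_fp(buffer)
buffer.seek(0)

# Decode with pydub; the segment can then be resampled, sped up, or gain-adjusted.
segment = AudioSegment.from_file(buffer, format="mp3")
print(segment.duration_seconds, segment.frame_rate)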
template_langgraph/services/streamlits/pages/chat_with_tools_agent.py

Lines changed: 114 additions & 39 deletions
@@ -1,3 +1,4 @@
+import io
 import os
 import tempfile
 from base64 import b64encode
@@ -10,6 +11,8 @@
 from langchain_community.callbacks.streamlit import (
     StreamlitCallbackHandler,
 )
+from pydub import AudioSegment
+from pydub.effects import speedup

 from template_langgraph.agents.chat_with_tools_agent.agent import (
     AgentState,
@@ -29,6 +32,56 @@ def load_whisper_model(model_size: str = "base"):
     return whisper.load_model(model_size)


+def synthesize_audio(
+    text: str,
+    language: str = "ja",
+    speed: float = 1.0,
+    pitch_shift: int = 0,
+    volume_db: float = 0.0,
+) -> bytes | None:
+    """Convert text to speech audio using gTTS and pydub adjustments."""
+
+    if not text.strip():
+        return None
+
+    try:
+        tts = gTTS(text=text, lang=language)
+        mp3_buffer = io.BytesIO()
+        tts.write_to_fp(mp3_buffer)
+        mp3_buffer.seek(0)
+
+        audio_segment = AudioSegment.from_file(mp3_buffer, format="mp3")
+        original_rate = audio_segment.frame_rate
+
+        if pitch_shift != 0:
+            semitone_ratio = 2.0 ** (pitch_shift / 12.0)
+            shifted = audio_segment._spawn(
+                audio_segment.raw_data,
+                overrides={"frame_rate": int(original_rate * semitone_ratio)},
+            )
+            audio_segment = shifted.set_frame_rate(original_rate)
+
+        if speed != 1.0:
+            if speed > 1.0:
+                audio_segment = speedup(audio_segment, playback_speed=float(speed))
+            else:
+                slowed_rate = max(int(original_rate * float(speed)), 1)
+                audio_segment = audio_segment._spawn(
+                    audio_segment.raw_data,
+                    overrides={"frame_rate": slowed_rate},
+                ).set_frame_rate(original_rate)
+
+        if volume_db != 0:
+            audio_segment += float(volume_db)
+
+        output_buffer = io.BytesIO()
+        audio_segment.export(output_buffer, format="mp3")
+        return output_buffer.getvalue()
+    except Exception as exc:  # pragma: no cover
+        st.error(f"音声合成に失敗しました: {exc}")
+        return None
+
+
 if "chat_history" not in st.session_state:
     st.session_state["chat_history"] = []
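A note on the synthesize_audio helper added above: the pitch branch is a resample-based shift. _spawn re-tags the raw samples with a frame rate scaled by 2 ** (semitones / 12), and set_frame_rate then converts back to the original rate, so the audio comes out higher (or lower) but also correspondingly shorter (or longer); pitch and tempo move together. speedup from pydub.effects only accelerates, which is why the speed < 1.0 case reuses the same frame-rate trick to slow playback down. A rough standalone illustration of the trick, with sample.mp3 as a hypothetical input file:

from pydub import AudioSegment

segment = AudioSegment.from_file("sample.mp3", format="mp3")  # hypothetical input file
rate = segment.frame_rate

# +7 semitones -> resample ratio 2 ** (7 / 12), roughly 1.498
ratio = 2.0 ** (7 / 12.0)
shifted = segment._spawn(segment.raw_data, overrides={"frame_rate": int(rate * ratio)})
shifted = shifted.set_frame_rate(rate)  # about 7 semitones higher and about 1.5x faster than the input

Because pitch and tempo are coupled here, combining a non-zero pitch_shift with a speed adjustment compounds the tempo change.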

@@ -51,34 +104,58 @@ def load_whisper_model(model_size: str = "base"):
 # 音声モードの場合、Whisper 設定を表示
 if input_output_mode == "音声":
     st.subheader("音声認識設定 (オプション)")
-    with st.expander("Whisper設定", expanded=False):
-        selected_model = st.sidebar.selectbox(
-            "Whisperモデル",
-            [
-                "tiny",
-                "base",
-                "small",
-                "medium",
-                "large",
-            ],
-            index=1,
-        )
-        transcription_language = st.sidebar.selectbox(
-            "文字起こし言語",
-            [
-                "auto",
-                "ja",
-                "en",
-            ],
-            index=0,
-            help="autoは言語自動判定です",
-        )
-        st.markdown(
-            """
-            - Whisperモデルは大きいほど高精度ですが、処理に時間がかかります。
-            - 文字起こし言語を指定することで、認識精度が向上します。
-            """
-        )
+    selected_model = st.sidebar.selectbox(
+        "Whisperモデル",
+        [
+            "tiny",
+            "base",
+            "small",
+            "medium",
+            "large",
+        ],
+        index=1,
+    )
+    transcription_language = st.sidebar.selectbox(
+        "文字起こし言語",
+        [
+            "auto",
+            "ja",
+            "en",
+        ],
+        index=0,
+        help="autoは言語自動判定です",
+    )
+    tts_language = st.sidebar.selectbox(
+        "TTS言語",
+        [
+            "ja",
+            "en",
+            "fr",
+            "de",
+            "ko",
+            "zh-CN",
+        ],
+        index=0,
+    )
+    tts_speed = st.sidebar.slider(
+        "再生速度",
+        min_value=0.5,
+        max_value=2.0,
+        step=0.1,
+        value=1.0,
+    )
+    tts_pitch = st.sidebar.slider(
+        "ピッチ (半音)",
+        min_value=-12,
+        max_value=12,
+        value=0,
+    )
+    tts_volume = st.sidebar.slider(
+        "音量 (dB)",
+        min_value=-20,
+        max_value=10,
+        value=0,
+    )

 st.divider()
 st.subheader("使用するツール")
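The new sliders line up with the synthesize_audio signature: Streamlit infers a slider's return type from its numeric arguments, so tts_speed comes back as a float while tts_pitch and tts_volume come back as ints, and the values are passed straight through in the playback hunk below. With everything left at its default (speed 1.0, pitch 0, volume 0 dB), all three adjustment branches are skipped and the result is simply the gTTS audio re-encoded to MP3, as in this illustrative call:

# Illustrative only: these defaults reduce the helper to a plain gTTS render plus MP3 re-encode.
audio_bytes = synthesize_audio(
    text="設定の確認です。",
    language="ja",
    speed=1.0,
    pitch_shift=0,
    volume_db=0.0,
)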
@@ -266,16 +343,14 @@ def load_whisper_model(model_size: str = "base"):
         # 音声モードの場合、音声出力を追加
         if input_output_mode == "音声":
             try:
-                # gTTSを使って音声生成
-                tts = gTTS(text=response_content, lang="ja")
-                with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_audio_file:
-                    tts.save(temp_audio_file.name)
-
-                # 音声ファイルを読み込んでstreamlit audio widgetで再生
-                with open(temp_audio_file.name, "rb") as audio_file:
-                    audio_bytes = audio_file.read()
-                st.audio(audio_bytes, format="audio/mp3", autoplay=True)
-                os.unlink(temp_audio_file.name)
-
+                with st.spinner("音声を生成中です..."):
+                    audio_bytes = synthesize_audio(
+                        text=response_content,
+                        language=tts_language,
+                        speed=tts_speed,
+                        pitch_shift=tts_pitch,
+                        volume_db=tts_volume,
+                    )
+                st.audio(audio_bytes, format="audio/mp3", autoplay=True)
             except Exception as e:
                 st.warning(f"音声出力でエラーが発生しました: {e}")
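One usage note, not part of the commit: synthesize_audio returns None for empty text or when synthesis fails, so a caller outside this page (for example a quick local check) could guard the result before saving or playing it:

audio_bytes = synthesize_audio(
    text="テスト応答です。",
    language="ja",
    speed=1.2,
    pitch_shift=2,
    volume_db=3.0,
)
if audio_bytes is not None:  # None on empty input or synthesis failure
    with open("reply.mp3", "wb") as f:  # hypothetical output path for offline listening
        f.write(audio_bytes)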

uv.lock

Lines changed: 11 additions & 0 deletions (generated lockfile; diff not rendered)
