1
+ import io
2
+ import os
3
+ import tempfile
1
4
from base64 import b64encode
5
+ from datetime import datetime
2
6
3
7
import streamlit as st
8
+ import whisper
9
+ from audio_recorder_streamlit import audio_recorder
10
+ from gtts import gTTS
4
11
from langchain_community .callbacks .streamlit import (
5
12
StreamlitCallbackHandler ,
6
13
)
14
+ from pydub import AudioSegment
15
+ from pydub .effects import speedup
7
16
8
17
from template_langgraph .agents .chat_with_tools_agent .agent import (
9
18
AgentState ,
@@ -16,11 +25,147 @@ def image_to_base64(image_bytes: bytes) -> str:
16
25
return b64encode (image_bytes ).decode ("utf-8" )
17
26
18
27
28
@st.cache_resource(show_spinner=False)
def load_whisper_model(model_size: str = "base"):
    """Return the Whisper speech-to-text model for ``model_size``.

    Cached with ``st.cache_resource``, so the model is loaded once per
    Streamlit server process and shared across sessions/reruns.
    """
    model = whisper.load_model(model_size)
    return model
33
+
34
+
35
+ def synthesize_audio (
36
+ text : str ,
37
+ language : str = "ja" ,
38
+ speed : float = 1.0 ,
39
+ pitch_shift : int = 0 ,
40
+ volume_db : float = 0.0 ,
41
+ ) -> bytes | None :
42
+ """Convert text to speech audio using gTTS and pydub adjustments."""
43
+
44
+ if not text .strip ():
45
+ return None
46
+
47
+ try :
48
+ tts = gTTS (text = text , lang = language )
49
+ mp3_buffer = io .BytesIO ()
50
+ tts .write_to_fp (mp3_buffer )
51
+ mp3_buffer .seek (0 )
52
+
53
+ audio_segment = AudioSegment .from_file (mp3_buffer , format = "mp3" )
54
+ original_rate = audio_segment .frame_rate
55
+
56
+ if pitch_shift != 0 :
57
+ semitone_ratio = 2.0 ** (pitch_shift / 12.0 )
58
+ shifted = audio_segment ._spawn (
59
+ audio_segment .raw_data ,
60
+ overrides = {"frame_rate" : int (original_rate * semitone_ratio )},
61
+ )
62
+ audio_segment = shifted .set_frame_rate (original_rate )
63
+
64
+ if speed != 1.0 :
65
+ if speed > 1.0 :
66
+ audio_segment = speedup (audio_segment , playback_speed = float (speed ))
67
+ else :
68
+ slowed_rate = max (int (original_rate * float (speed )), 1 )
69
+ audio_segment = audio_segment ._spawn (
70
+ audio_segment .raw_data ,
71
+ overrides = {"frame_rate" : slowed_rate },
72
+ ).set_frame_rate (original_rate )
73
+
74
+ if volume_db != 0 :
75
+ audio_segment += float (volume_db )
76
+
77
+ output_buffer = io .BytesIO ()
78
+ audio_segment .export (output_buffer , format = "mp3" )
79
+ return output_buffer .getvalue ()
80
+ except Exception as exc : # pragma: no cover
81
+ st .error (f"音声合成に失敗しました: { exc } " )
82
+ return None
83
+
84
+
19
85
# Initialize the running chat transcript once per session.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []

# Sidebar: input/output mode selection, tool selection, and agent construction
with st.sidebar:
    st.subheader("入出力モード")

    # Persist the chosen I/O mode across Streamlit reruns.
    if "input_output_mode" not in st.session_state:
        st.session_state["input_output_mode"] = "テキスト"

    input_output_mode = st.radio(
        "モードを選択してください",
        options=["テキスト", "音声"],
        index=0 if st.session_state["input_output_mode"] == "テキスト" else 1,
        help="テキスト: 従来のテキスト入力/出力, 音声: マイク入力/音声出力",
    )
    st.session_state["input_output_mode"] = input_output_mode

    # Voice mode only: show the recorder widget plus Whisper/gTTS settings.
    # NOTE(review): audio_bytes, selected_model, transcription_language and the
    # tts_* values are defined ONLY in voice mode — later code must not read
    # them when the mode is テキスト.
    if input_output_mode == "音声":
        st.subheader("音声認識設定 (オプション)")
        audio_bytes = audio_recorder(
            text="クリックして音声入力👉️",
            recording_color="red",
            neutral_color="gray",
            icon_name="microphone",
            icon_size="2x",
            key="audio_input",
        )
        # NOTE(review): st.sidebar.* inside `with st.sidebar:` is redundant
        # (plain st.* would render in the same place) but is harmless.
        selected_model = st.sidebar.selectbox(
            "Whisperモデル",
            [
                "tiny",
                "base",
                "small",
                "medium",
                "large",
            ],
            index=1,  # default to "base"
        )
        transcription_language = st.sidebar.selectbox(
            "文字起こし言語",
            [
                "auto",
                "ja",
                "en",
            ],
            index=0,
            help="autoは言語自動判定です",
        )
        tts_language = st.sidebar.selectbox(
            "TTS言語",
            [
                "ja",
                "en",
                "fr",
                "de",
                "ko",
                "zh-CN",
            ],
            index=0,
        )
        tts_speed = st.sidebar.slider(
            "再生速度",
            min_value=0.5,
            max_value=2.0,
            step=0.1,
            value=1.0,
        )
        tts_pitch = st.sidebar.slider(
            "ピッチ (半音)",
            min_value=-12,
            max_value=12,
            value=0,
        )
        tts_volume = st.sidebar.slider(
            "音量 (dB)",
            min_value=-20,
            max_value=10,
            value=0,
        )

    st.divider()
    st.subheader("使用するツール")

    # Fetch the list of available tools
@@ -63,16 +208,63 @@ def image_to_base64(image_bytes: bytes) -> str:
63
208
else :
64
209
st .chat_message ("assistant" ).write (msg .content )
65
210
66
- if prompt := st .chat_input (
67
- accept_file = "multiple" ,
68
- file_type = [
69
- "png" ,
70
- "jpg" ,
71
- "jpeg" ,
72
- "gif" ,
73
- "webp" ,
74
- ],
75
- ):
211
# Input section: branch on the selected I/O mode
prompt = None
prompt_text = ""
# NOTE(review): prompt_files is never populated in voice mode; in text mode
# attachments arrive inside `prompt` itself — confirm this list is still used.
prompt_files = []

if input_output_mode == "音声":
    if audio_bytes:
        st.audio(audio_bytes, format="audio/wav")

        # Persist the recording to a temp file — Whisper's transcribe()
        # is given a file path below. delete=False so the path survives
        # past the `with`; cleanup happens in the `finally` block.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file:
            temp_audio_file.write(audio_bytes)
            temp_audio_file_path = temp_audio_file.name
        st.download_button(
            label="🎧 録音データを保存",
            data=audio_bytes,
            file_name=f"recorded_{datetime.now().strftime('%Y%m%d_%H%M%S')}.wav",
            mime="audio/wav",
            use_container_width=True,
        )
        try:
            # NOTE(review): this inner mode check is redundant — we are
            # already inside the 音声 branch.
            if input_output_mode == "音声":
                with st.spinner("音声を認識中..."):
                    model = load_whisper_model(selected_model)
                    # Whisper auto-detects the language when language=None.
                    language_param = None if transcription_language == "auto" else transcription_language
                    result = model.transcribe(str(temp_audio_file_path), language=language_param)
                    transcribed_text = result.get("text", "").strip()
                    prompt_text = transcribed_text

                    if prompt_text:
                        st.success(f"音声認識完了: {prompt_text}")
                        prompt = prompt_text
                    else:
                        st.warning("音声が認識できませんでした")
        except Exception as e:
            st.error(f"音声認識でエラーが発生しました: {e}")
            # NOTE(review): this assignment appears dead — `prompt` stays
            # None here, so this error text is never sent to the agent.
            prompt_text = "音声入力でエラーが発生しました"
        finally:
            # Always remove the temp recording, even if transcription failed.
            if os.path.exists(temp_audio_file_path):
                os.unlink(temp_audio_file_path)

else:
    # Existing text input mode
    if prompt := st.chat_input(
        accept_file="multiple",
        file_type=[
            "png",
            "jpg",
            "jpeg",
            "gif",
            "webp",
        ],
    ):
        pass  # prompt is already set by the walrus assignment

# Shared input-processing logic
if prompt:
    user_display_items = []
    message_parts = []
@@ -141,4 +333,22 @@ def image_to_base64(image_bytes: bytes) -> str:
141
333
)
142
334
last_message = response ["messages" ][- 1 ]
143
335
st .session_state ["chat_history" ].append (last_message )
144
- st .write (last_message .content )
336
+
337
+ # レスポンス表示とオーディオ出力
338
+ response_content = last_message .content
339
+ st .write (response_content )
340
+
341
+ # 音声モードの場合、音声出力を追加
342
+ if input_output_mode == "音声" :
343
+ try :
344
+ with st .spinner ("音声を生成中です..." ):
345
+ audio_bytes = synthesize_audio (
346
+ text = response_content ,
347
+ language = tts_language ,
348
+ speed = tts_speed ,
349
+ pitch_shift = tts_pitch ,
350
+ volume_db = tts_volume ,
351
+ )
352
+ st .audio (audio_bytes , format = "audio/mp3" , autoplay = True )
353
+ except Exception as e :
354
+ st .warning (f"音声出力でエラーが発生しました: { e } " )
0 commit comments