Skip to content

Commit 2e25ff2

Browse files
committed
use OpenAI Whisper for STT, instead of using Azure
1 parent 7a7bc18 commit 2e25ff2

File tree

3 files changed

+387
-70
lines changed

3 files changed

+387
-70
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ dependencies = [
2727
"langgraph>=0.6.2",
2828
"langgraph-supervisor>=0.0.29",
2929
"mlflow>=3.4.0",
30+
"openai-whisper>=20250625",
3031
"openai[realtime]>=1.98.0",
3132
"opentelemetry-api>=1.36.0",
3233
"opentelemetry-exporter-otlp>=1.36.0",

template_langgraph/services/streamlits/pages/chat_with_tools_agent.py

Lines changed: 66 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
from base64 import b64encode
1+
import os
22
import tempfile
3-
from os import getenv
3+
from base64 import b64encode
4+
from datetime import datetime
45

56
import streamlit as st
7+
import whisper
68
from audio_recorder_streamlit import audio_recorder
79
from gtts import gTTS
810
from langchain_community.callbacks.streamlit import (
911
StreamlitCallbackHandler,
1012
)
11-
from langchain_community.document_loaders.parsers.audio import AzureOpenAIWhisperParser
12-
from langchain_core.documents.base import Blob
1313

1414
from template_langgraph.agents.chat_with_tools_agent.agent import (
1515
AgentState,
@@ -22,52 +22,64 @@ def image_to_base64(image_bytes: bytes) -> str:
2222
return b64encode(image_bytes).decode("utf-8")
2323

2424

25+
@st.cache_resource(show_spinner=False)
26+
def load_whisper_model(model_size: str = "base"):
27+
"""Load a Whisper model only once per session."""
28+
29+
return whisper.load_model(model_size)
30+
31+
2532
if "chat_history" not in st.session_state:
2633
st.session_state["chat_history"] = []
2734

2835
# Sidebar: 入出力モード選択、ツール選択とエージェントの構築
2936
with st.sidebar:
3037
st.subheader("入出力モード")
31-
38+
3239
# 入出力モード選択
3340
if "input_output_mode" not in st.session_state:
3441
st.session_state["input_output_mode"] = "テキスト"
35-
42+
3643
input_output_mode = st.radio(
3744
"モードを選択してください",
3845
options=["テキスト", "音声"],
3946
index=0 if st.session_state["input_output_mode"] == "テキスト" else 1,
40-
help="テキスト: 従来のテキスト入力/出力, 音声: マイク入力/音声出力"
47+
help="テキスト: 従来のテキスト入力/出力, 音声: マイク入力/音声出力",
4148
)
4249
st.session_state["input_output_mode"] = input_output_mode
43-
44-
# 音声モードの場合、Azure OpenAI設定を表示
50+
51+
# 音声モードの場合、Whisper 設定を表示
4552
if input_output_mode == "音声":
4653
st.subheader("音声認識設定 (オプション)")
47-
with st.expander("Azure OpenAI Whisper設定", expanded=False):
48-
azure_openai_endpoint = st.text_input(
49-
"AZURE_OPENAI_ENDPOINT",
50-
value=getenv("AZURE_OPENAI_ENDPOINT", ""),
51-
help="Azure OpenAI リソースのエンドポイント"
54+
with st.expander("Whisper設定", expanded=False):
55+
selected_model = st.sidebar.selectbox(
56+
"Whisperモデル",
57+
[
58+
"tiny",
59+
"base",
60+
"small",
61+
"medium",
62+
"large",
63+
],
64+
index=1,
5265
)
53-
azure_openai_api_key = st.text_input(
54-
"AZURE_OPENAI_API_KEY",
55-
value=getenv("AZURE_OPENAI_API_KEY", ""),
56-
type="password",
57-
help="Azure OpenAI リソースのAPIキー"
66+
transcription_language = st.sidebar.selectbox(
67+
"文字起こし言語",
68+
[
69+
"auto",
70+
"ja",
71+
"en",
72+
],
73+
index=0,
74+
help="autoは言語自動判定です",
5875
)
59-
azure_openai_api_version = st.text_input(
60-
"AZURE_OPENAI_API_VERSION",
61-
value=getenv("AZURE_OPENAI_API_VERSION", "2024-02-01"),
62-
help="Azure OpenAI APIバージョン"
76+
st.markdown(
77+
"""
78+
- Whisperモデルは大きいほど高精度ですが、処理に時間がかかります。
79+
- 文字起こし言語を指定することで、認識精度が向上します。
80+
"""
6381
)
64-
azure_openai_model_stt = st.text_input(
65-
"AZURE_OPENAI_MODEL_STT",
66-
value=getenv("AZURE_OPENAI_MODEL_STT", "whisper"),
67-
help="音声認識用のデプロイ名"
68-
)
69-
st.caption("※設定しない場合は、音声入力時にプレースホルダーテキストが使用されます")
70-
82+
7183
st.divider()
7284
st.subheader("使用するツール")
7385

@@ -121,60 +133,47 @@ def image_to_base64(image_bytes: bytes) -> str:
121133
audio_bytes = audio_recorder(
122134
text="クリックして録音",
123135
recording_color="red",
124-
neutral_color="black",
136+
neutral_color="gray",
125137
icon_name="microphone",
126138
icon_size="2x",
127-
key="audio_input"
139+
key="audio_input",
128140
)
129-
141+
130142
if audio_bytes:
131143
st.audio(audio_bytes, format="audio/wav")
132-
144+
133145
# 音声データを一時ファイルに保存
134146
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file:
135147
temp_audio_file.write(audio_bytes)
136148
temp_audio_file_path = temp_audio_file.name
137-
138-
# Azure OpenAI Whisperが設定されている場合は音声認識を実施
149+
st.download_button(
150+
label="🎧 録音データを保存",
151+
data=audio_bytes,
152+
file_name=f"recorded_{datetime.now().strftime('%Y%m%d_%H%M%S')}.wav",
153+
mime="audio/wav",
154+
use_container_width=True,
155+
)
139156
try:
140-
if (input_output_mode == "音声" and
141-
azure_openai_endpoint and azure_openai_api_key and
142-
azure_openai_model_stt):
143-
157+
if input_output_mode == "音声":
144158
with st.spinner("音声を認識中..."):
145-
audio_blob = Blob(path=temp_audio_file_path)
146-
parser = AzureOpenAIWhisperParser(
147-
api_key=azure_openai_api_key,
148-
azure_endpoint=azure_openai_endpoint,
149-
api_version=azure_openai_api_version,
150-
deployment_name=azure_openai_model_stt,
151-
)
152-
documents = parser.lazy_parse(blob=audio_blob)
153-
results = [doc.page_content for doc in documents]
154-
prompt_text = "\n".join(results).strip()
155-
159+
model = load_whisper_model(selected_model)
160+
language_param = None if transcription_language == "auto" else transcription_language
161+
result = model.transcribe(str(temp_audio_file_path), language=language_param)
162+
transcribed_text = result.get("text", "").strip()
163+
prompt_text = transcribed_text
164+
156165
if prompt_text:
157166
st.success(f"音声認識完了: {prompt_text}")
158167
prompt = prompt_text
159168
else:
160169
st.warning("音声が認識できませんでした")
161-
prompt = None
162-
else:
163-
# Azure OpenAI設定がない場合はプレースホルダー
164-
prompt_text = "音声入力を受信しました(音声認識設定が必要です)"
165-
prompt = prompt_text
166-
st.info("音声認識を使用するには、サイドバーでAzure OpenAI設定を入力してください")
167-
168170
except Exception as e:
169171
st.error(f"音声認識でエラーが発生しました: {e}")
170172
prompt_text = "音声入力でエラーが発生しました"
171-
prompt = prompt_text
172173
finally:
173-
# 一時ファイルを削除
174-
import os
175174
if os.path.exists(temp_audio_file_path):
176175
os.unlink(temp_audio_file_path)
177-
176+
178177
else:
179178
# 既存のテキスト入力モード
180179
if prompt := st.chat_input(
@@ -259,27 +258,24 @@ def image_to_base64(image_bytes: bytes) -> str:
259258
)
260259
last_message = response["messages"][-1]
261260
st.session_state["chat_history"].append(last_message)
262-
261+
263262
# レスポンス表示とオーディオ出力
264263
response_content = last_message.content
265264
st.write(response_content)
266-
265+
267266
# 音声モードの場合、音声出力を追加
268267
if input_output_mode == "音声":
269268
try:
270269
# gTTSを使って音声生成
271-
tts = gTTS(text=response_content, lang='ja')
270+
tts = gTTS(text=response_content, lang="ja")
272271
with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_audio_file:
273272
tts.save(temp_audio_file.name)
274-
273+
275274
# 音声ファイルを読み込んでstreamlit audio widgetで再生
276275
with open(temp_audio_file.name, "rb") as audio_file:
277276
audio_bytes = audio_file.read()
278277
st.audio(audio_bytes, format="audio/mp3", autoplay=True)
279-
280-
# 一時ファイルを削除
281-
import os
282278
os.unlink(temp_audio_file.name)
283-
279+
284280
except Exception as e:
285281
st.warning(f"音声出力でエラーが発生しました: {e}")

0 commit comments

Comments
 (0)