Skip to content

Commit eff8501

Browse files
committed
move codes to stt module
1 parent f8bcb01 commit eff8501

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

template_langgraph/services/streamlits/pages/chat_with_tools_agent.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from datetime import datetime
55

66
import streamlit as st
7-
import whisper
87
from audio_recorder_streamlit import audio_recorder
98
from langchain_community.callbacks.streamlit import (
109
StreamlitCallbackHandler,
@@ -14,6 +13,7 @@
1413
AgentState,
1514
ChatWithToolsAgent,
1615
)
16+
from template_langgraph.speeches.stt import SttWrapper
1717
from template_langgraph.speeches.tts import synthesize_audio
1818
from template_langgraph.tools.common import get_default_tools
1919

@@ -23,10 +23,11 @@ def image_to_base64(image_bytes: bytes) -> str:
2323

2424

2525
@st.cache_resource(show_spinner=False)
26-
def load_whisper_model(model_size: str = "base"):
27-
"""Load a Whisper model only once per session."""
28-
29-
return whisper.load_model(model_size)
26+
def load_stt_wrapper(model_size: str = "base"):
    """Create an ``SttWrapper`` with the requested model already loaded.

    Args:
        model_size: Model name forwarded to ``SttWrapper.load_model``
            (defaults to ``"base"``).

    Returns:
        The ready-to-use ``SttWrapper`` instance.
    """
    wrapper = SttWrapper()
    # Load eagerly so the returned object can transcribe immediately.
    wrapper.load_model(model_size)
    return wrapper
3031

3132

3233
if "chat_history" not in st.session_state:
@@ -178,9 +179,9 @@ def load_whisper_model(model_size: str = "base"):
178179
try:
179180
if input_output_mode == "音声":
180181
with st.spinner("音声を認識中..."):
181-
model = load_whisper_model(selected_model)
182+
stt_wrapper = load_stt_wrapper(selected_model)
182183
language_param = None if transcription_language == "auto" else transcription_language
183-
result = model.transcribe(str(temp_audio_file_path), language=language_param)
184+
# SttWrapper.transcribe returns the transcribed text itself (a str), not
# Whisper's result dict, so it must not be indexed with .get("text").
result = stt_wrapper.transcribe(str(temp_audio_file_path), language=language_param)
transcribed_text = result.strip()
185186
prompt_text = transcribed_text
186187

template_langgraph/speeches/stt.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import logging
2+
3+
import whisper
4+
5+
from template_langgraph.loggers import get_logger
6+
7+
# Module-level logger for the speech-to-text helpers, fixed at DEBUG verbosity.
logger = get_logger(name=__name__, verbosity=logging.DEBUG)
11+
12+
13+
class SttWrapper:
    """Thin wrapper around a Whisper speech-to-text model.

    The model is not loaded on construction; call ``load_model`` once
    before calling ``transcribe``.
    """

    def __init__(self):
        # Populated by load_model(); None means "no model loaded yet".
        self.model = None

    def load_model(self, model_size: str):
        """Load the Whisper model identified by *model_size* (e.g. ``"base"``).

        Args:
            model_size: Model name passed straight to ``whisper.load_model``.
        """
        logger.info(f"Loading Whisper model: {model_size}")
        self.model = whisper.load_model(model_size)

    def transcribe(
        self,
        audio_path: str,
        language: "str | None" = None,
    ) -> str:
        """Transcribe the audio file at *audio_path* and return its text.

        Args:
            audio_path: Path of the audio file to transcribe.
            language: Language code forwarded to the model, or ``None``.
                NOTE(review): ``None`` presumably lets Whisper auto-detect
                the language -- confirm against the Whisper documentation.

        Returns:
            The ``"text"`` field of the model's result, i.e. a plain string
            rather than the full result dictionary.

        Raises:
            RuntimeError: If ``load_model`` has not been called yet.
        """
        # Fail with a clear message instead of an AttributeError on None.
        if self.model is None:
            raise RuntimeError(
                "STT model is not loaded; call load_model() before transcribe()."
            )
        logger.info(f"Transcribing audio: {audio_path} with language: {language}")
        result = self.model.transcribe(
            audio=audio_path,
            language=language,
        )
        return result["text"]

0 commit comments

Comments
 (0)