-
Notifications
You must be signed in to change notification settings - Fork 2.6k
feat: Support gemini stt model #1876
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| # coding=utf-8 | ||
| from typing import Dict | ||
|
|
||
| from common import forms | ||
| from common.exception.app_exception import AppApiException | ||
| from common.forms import BaseForm | ||
| from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode | ||
|
|
||
|
|
||
class GeminiSTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Gemini speech-to-text models.

    The form exposes a single required field, the API key, rendered as a
    password input.
    """
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate a credential against the given provider and model.

        Three checks run in order: the model type must appear in the
        provider's type list, every required key must be present in
        ``model_credential``, and ``check_auth()`` on the model built by the
        provider must complete without error.

        :param model_type: model type value looked up in the provider's type list
        :param model_name: name of the model instantiated for the auth check
        :param model_credential: credential values; ``api_key`` is required
        :param provider: object exposing ``get_model_type_list`` and ``get_model``
        :param raise_exception: when true, a missing key or failed auth check
            raises instead of returning ``False``
        :return: ``True`` when every check passes, otherwise ``False``
        :raises AppApiException: always for an unsupported model type and for
            an ``AppApiException`` raised during the auth check; for other
            failures only when ``raise_exception`` is true
        """
        type_supported = any(
            entry.get('value') == model_type for entry in provider.get_model_type_list()
        )
        if not type_supported:
            # An unsupported type is reported regardless of raise_exception.
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for required_key in ('api_key',):
            if required_key in model_credential:
                continue
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{required_key} 字段为必填字段')

        try:
            provider.get_model(model_type, model_name, model_credential).check_auth()
        except AppApiException:
            # Application-level errors pass through untouched, even when
            # raise_exception is false.
            raise
        except Exception as error:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(error)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through the inherited ``encryption``.

        A missing ``api_key`` is treated as the empty string.
        """
        encrypted_key = super().encryption(model.get('api_key', ''))
        return {**model, 'api_key': encrypted_key}

    def get_model_params_setting_form(self, model_name):
        """Return ``None``: no parameter-setting form is defined for these models."""
        return None
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,12 +13,15 @@ | |
| ModelInfoManage | ||
| from setting.models_provider.impl.gemini_model_provider.credential.image import GeminiImageModelCredential | ||
| from setting.models_provider.impl.gemini_model_provider.credential.llm import GeminiLLMModelCredential | ||
| from setting.models_provider.impl.gemini_model_provider.credential.stt import GeminiSTTModelCredential | ||
| from setting.models_provider.impl.gemini_model_provider.model.image import GeminiImage | ||
| from setting.models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel | ||
| from setting.models_provider.impl.gemini_model_provider.model.stt import GeminiSpeechToText | ||
| from smartdoc.conf import PROJECT_DIR | ||
|
|
||
# Module-level credential instances, one per Gemini model type (LLM, image, STT).
# The STT instance is the one referenced by the entries of model_stt_info_list.
| gemini_llm_model_credential = GeminiLLMModelCredential() | ||
| gemini_image_model_credential = GeminiImageModelCredential() | ||
| gemini_stt_model_credential = GeminiSTTModelCredential() | ||
|
|
||
| model_info_list = [ | ||
| ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新', | ||
|
|
@@ -42,14 +45,25 @@ | |
| GeminiImage), | ||
| ] | ||
|
|
||
|
|
||
# Speech-to-text models offered by the Gemini provider. The first entry is the
# one registered as the STT default when model_info_manage is built.
model_stt_info_list = [
    ModelInfo('gemini-1.5-flash', '最新的Gemini 1.5 Flash模型,随Google更新而更新',
              ModelTypeConst.STT,
              gemini_stt_model_credential,
              GeminiSpeechToText),
    # Description fixed: this entry previously repeated the "1.5 Flash" text.
    ModelInfo('gemini-1.5-pro', '最新的Gemini 1.5 Pro模型,随Google更新而更新',
              ModelTypeConst.STT,
              gemini_stt_model_credential,
              GeminiSpeechToText),
]
|
|
||
def _build_model_info_manage():
    """Build the provider's model registry from the LLM, image and STT lists.

    All three lists are appended first; afterwards the first entry of each
    list is registered as a default, in the same list order.
    """
    catalogs = (model_info_list, model_image_info_list, model_stt_info_list)
    builder = ModelInfoManage.builder()
    for catalog in catalogs:
        builder = builder.append_model_info_list(catalog)
    for catalog in catalogs:
        builder = builder.append_default_model_info(catalog[0])
    return builder.build()


model_info_manage = _build_model_info_manage()
|
|
||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There are a few minor improvements and optimizations that can be made to enhance the code:
Here's the revised version with these changes: @@ -8,7 +8,6 @@
from setting.models_provider.impl.gemini_model_provider.model.stt import GeminiSpeechToText
from smartdoc.conf import PROJECT_DIR
-gemini_stt_model_credential = GeminiSTTModelCredential()
-model_info_list = [
+stt_models = [
ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新',
ModelTypeConst.GENERAL_MODEL,
None,
@@ -42,13 +41,18 @@
GeminiImage),
]
+
+models_stt_list = [
+ ModelInfo('gemini-1.5-flash', '最新的Gemini 1.5 Flash模型,随Google更新而更新',
+ ModelTypeConst.STT,
+ gemini_stt_model_credential,
+ GeminiSpeechToText),
+]
+
model_info_manage = (
ModelInfoManage.builder()
.append_model_info_list(model_info_list)
.append_model_info_list(stt_models)
.append_default_model_info(default_model)
.build()
)
These changes make the code cleaner and more readable. |
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| import asyncio | ||
| import io | ||
| from typing import Dict | ||
|
|
||
| from langchain_core.messages import HumanMessage | ||
| from langchain_google_genai import ChatGoogleGenerativeAI | ||
| from openai import OpenAI | ||
|
|
||
| from common.config.tokenizer_manage_config import TokenizerManage | ||
| from setting.models_provider.base_model_provider import MaxKBBaseModel | ||
| from setting.models_provider.impl.base_stt import BaseSpeechToText | ||
| import google.generativeai as genai | ||
|
|
||
|
|
||
def custom_get_token_ids(text: str):
    """Encode ``text`` with the tokenizer from ``TokenizerManage`` and return the result."""
    return TokenizerManage.get_tokenizer().encode(text)
|
|
||
|
|
||
class GeminiSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text model that sends audio to a Gemini chat model via LangChain."""
    api_key: str
    model: str

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the base classes, then store ``api_key``."""
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create an instance for ``model_name`` using the credential's ``api_key``.

        ``max_tokens`` and ``temperature`` are forwarded to the constructor
        only when they are present in ``model_kwargs`` and not ``None``.
        ``model_type`` is accepted but not used here.
        """
        forwarded = {
            option: model_kwargs[option]
            for option in ('max_tokens', 'temperature')
            if model_kwargs.get(option) is not None
        }
        return GeminiSpeechToText(
            model=model_name,
            api_key=model_credential.get('api_key'),
            **forwarded,
        )

    def check_auth(self):
        """Send a short greeting to the model; any error from the call propagates."""
        gemini = ChatGoogleGenerativeAI(
            google_api_key=self.api_key,
            model=self.model,
        )
        # '你好' ("hello") is a minimal prompt; the reply itself is discarded.
        gemini.invoke('你好')

    def speech_to_text(self, audio_file):
        """Transcribe ``audio_file`` and return the content of the model's reply.

        :param audio_file: object with a ``read()`` method returning the audio data
        :return: the ``content`` attribute of the response from ``invoke``
        """
        gemini = ChatGoogleGenerativeAI(
            google_api_key=self.api_key,
            model=self.model,
        )
        payload = audio_file.read()
        # The text part asks the model to convert the audio to text.
        # NOTE(review): the MIME type is fixed to audio/mp3 whatever the
        # uploaded format is - confirm callers only pass MP3 data.
        request = HumanMessage(content=[
            {'type': 'text', 'text': '把音频转成文字'},
            {"type": "media", 'mime_type': 'audio/mp3', "data": payload},
        ])
        return gemini.invoke([request]).content
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The provided code looks generally well-structured for speech-to-text using Google Generative AI (Gemini) through LangChain. However, there are a few points to consider:
Here is an updated version of the code incorporating some of these suggestions: @@ -0,0 +1,79 @@
+import asyncio
+import io
+from typing import Dict, Optional
+import google.generativeai as genai
+from common.config.tokenizer_manage_config import TokenizerManage
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_stt import BaseSpeechToText
+from langchain_core.messages import HumanMessage
def custom_get_token_ids(text: str) -> list[int]:
"""Converts text into token IDs using appropriate tokenization."""
tokenizer = TokenizerManage.get_tokenizer()
return tokenizer.encode(text)
def convert_bytes_to_content(audio_data: bytes) -> HumanMessage:
"""
Converts binary data representing a media file into a LangChain message.
:param audio_data: Binary content of the audio file.
:return: A HumanMessage containing the media details.
"""
return HumanMessage(content=[
{'type': 'text', 'text': '把音频转成文字'},
{"type": "media", 'mime_type': 'audio/mp3', "data": audio_data}
])
class GeminiSpeechToText(MaxKBBaseModel, BaseSpeechToText):
"""A class implementing speech to text functionality using Google Generative AI."""
api_key: str
model: str
max_tokens: int = 160 # Default value if not specified
temperature: float = 0.8 # Default value if not specified
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key', '')
self.max_tokens = kwargs.get('max_tokens', self.max_tokens)
self.temperature = kwargs.get('temperature', self.temperature)
@staticmethod
def new_instance(
model_type: str,
model_name: str,
model_credential: Dict[str, object],
**model_kwargs
) -> 'GeminiSpeechToText':
additional_params = {
'max_tokens': model_kwargs.get('max_tokens') or ModelConstants.DEFAULT_GEMINI_MAX_TOKENS,
'temperature': model_kwargs.get('temperature') or ModelConstants.DEFAULT_GEMINI_TEMPERATURE
}
return GeminiSpeechToText(
model=model_name if model_name else 'gemini-pro',
api_key=model_credential.get('api_key'),
**additional_params,
)
def check_auth(self) -> bool:
try:
client = genai.ChatGoogleGenerativeAI(
model=self.model,
google_api_key=self.api_key
)
response = client.chat("你好")
# Logging success
print(f"Authentication successful via {response.status}")
return True
except Exception as e:
# Logging error
print(f"Failed authentication attempt: {repr(e)}")
return False
async def speak_and_transcribe(
self,
speech_file_path: str
) -> Optional[str]:
"""
Asynchronously transcribes the contents of a given .wav file.
This method reads the WAV file, processes it through GGAI,
and captures the transcription result.
Args:
speech_file_path(str): Path to the input .wav audio file.
Returns:
An optional string containing the transcribed text; otherwise None upon failure.
Raises:
FileNotFoundError: The wav files at path does not exist
PermissionError: Access denied to read the wav file
RuntimeError: Error communicating with server while attempting transcription.
Example usage:
await tts.speak_and_transcribe(input_sound="path/to/input.wav")
Output example:
The output will show something like:
[INFO] Authentication successful via ok
[SUCCESSFUL TRANSCRIPTION]
Hello world! How can I assist you today?
"""
try:
with open(speech_file_path, mode='rb') as f_read:
audio_data = f_read.read()
client = genai.ChatGoogleGenerativeAI(
model=self.model,
google_api_key=self.api_key
)
messages = [
{'role': 'system', 'content': ''}
,{
"role": "user",
"content":convert_bytes_to_content(audio_data),
}
]
resp = await client.generate_async(messages=[messages[1]])
transcript = ""
for chunk in resp.streamed_outputs():
for choice in chunk.choices:
if choice.message.role == "assistant":
transcript += choice.text[len(chunk.cumulative_end_index):]
# Log transcript successfully captured
print("[SUCCESSFUL TRANSCRIPTION]")
return transcript.strip("\n\r ") # Remove line breaks
except FileNotFoundError:
print("FileNotFoundError: The sound file was not found.")
raise
except PermissionError:
print("PermissionError: There was permission denied reading the sound file.")
raise
except RuntimeError as rte:
print(f"The system could't complete because of an internal processing issue:\n{rte}")
raise
except ValueError as ve:
print(f"Unknown Value Error occurred:\n{ve}")
raise
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here are my observations:
The code looks mostly consistent with common practices:
- The file declares its encoding (`# coding=utf-8`), which is good.

However, there are some points to consider for improvement:
- Consider using `logging` instead of raising exceptions in all places where error handling might be appropriate.
- The `encryption_dict` method can be simplified since it simply wraps another function (`super().encryption()`).

Overall, the code is well-written but could benefit from better documentation and possibly refactoring to improve clarity and maintainability.