feat: Support gemini stt model #1876
Conversation
Adding the "do-not-merge/release-note-label-needed" label because no release-note block was detected; please follow our release note process to remove it. Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository.

[APPROVALNOTIFIER] This PR is NOT APPROVED. This pull-request has been approved by: — the full list of commands accepted by this bot can be found here. Needs approval from an approver in each of these files. Approvers can indicate their approval by writing `/approve` in a comment.
```python
        {"type": "media", 'mime_type': 'audio/mp3', "data": audio_data}
    ])
    res = client.invoke([msg])
    return res.content
```
The provided code looks generally well-structured for a speech-to-text integration built on Google Generative AI (Gemini). However, there are a few points to consider:
- **Tokenization:** The `custom_get_token_ids` function uses a tokenizer from a module named `TokenizerManage`, but it's unclear where this module or its dependencies are defined or implemented. Ensure that the necessary functions are correctly imported and accessible.

- **Static Method Documentation:**

  ```python
  @staticmethod
  def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  ```

  It would be helpful to include type information for `model_type` and `model_name` within the static method signature as well:

  ```python
  @staticmethod
  def new_instance(
      model_type: str,
      model_name: str,
      model_credential: Dict[str, object],
      **model_kwargs
  ):
  ```

- **Logging:** The current checks use `print(response_list)`, which should likely be replaced with logging statements for better readability and debugging support (a sketch follows this list).

- **Optional Parameters:** The default values of `max_tokens` and `temperature` are set in the constructor, but it would be clearer to define them as explicit named defaults instead of relying on raw keyword-argument lookups (also shown in the sketch below).

- **Method Comments:** While comments do exist in the code, more detailed descriptions would help other developers understand each part of the flow.

- **Async/Await Considerations:** If you're planning to integrate asynchronous operations, ensure all methods supporting them are declared appropriately (`async def`) along with proper handling of async tasks (e.g., using `await`).

- **Dependencies:** Make sure that all required libraries (`langchain_core`, `typing`, `google-generativeai`, etc.) are installed and available in your environment.
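For the logging and explicit-defaults points above, here is a minimal sketch of what that could look like. The logger name and the constant names are illustrative, not part of the PR; the base classes are the ones imported elsewhere in this change:

```python
import logging

from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_stt import BaseSpeechToText

logger = logging.getLogger(__name__)


class GeminiSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    # Explicit, named defaults instead of values buried in **kwargs lookups.
    DEFAULT_MAX_TOKENS = 160
    DEFAULT_TEMPERATURE = 0.8

    def __init__(self, api_key: str = '', max_tokens: int = DEFAULT_MAX_TOKENS,
                 temperature: float = DEFAULT_TEMPERATURE, **kwargs):
        super().__init__(**kwargs)
        self.api_key = api_key
        self.max_tokens = max_tokens
        self.temperature = temperature
        # logger.debug replaces ad-hoc calls such as print(response_list).
        logger.debug('GeminiSpeechToText configured: max_tokens=%s, temperature=%s',
                     max_tokens, temperature)
```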
Here is an updated version of the code incorporating some of these suggestions (note that `ChatGoogleGenerativeAI` is provided by the `langchain_google_genai` package, not by `google.generativeai`, so the client construction below uses that class):
```python
from typing import Dict, Optional

from langchain_core.messages import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_stt import BaseSpeechToText


def custom_get_token_ids(text: str) -> list[int]:
    """Convert text into token IDs using the project's tokenizer."""
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


def convert_bytes_to_content(audio_data: bytes) -> HumanMessage:
    """
    Convert binary data representing a media file into a LangChain message.

    :param audio_data: Binary content of the audio file.
    :return: A HumanMessage containing the media details.
    """
    return HumanMessage(content=[
        {'type': 'text', 'text': '把音频转成文字'},  # "Transcribe the audio to text"
        {'type': 'media', 'mime_type': 'audio/mp3', 'data': audio_data}
    ])


class GeminiSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text functionality backed by Google Generative AI (Gemini)."""

    api_key: str
    model: str
    max_tokens: int = 160  # default value if not specified
    temperature: float = 0.8  # default value if not specified

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key', '')
        self.model = kwargs.get('model', '')
        self.max_tokens = kwargs.get('max_tokens', self.max_tokens)
        self.temperature = kwargs.get('temperature', self.temperature)

    @staticmethod
    def new_instance(
            model_type: str,
            model_name: str,
            model_credential: Dict[str, object],
            **model_kwargs
    ) -> 'GeminiSpeechToText':
        # Fall back to the class-level defaults when no override is supplied.
        additional_params = {
            'max_tokens': model_kwargs.get('max_tokens') or GeminiSpeechToText.max_tokens,
            'temperature': model_kwargs.get('temperature') or GeminiSpeechToText.temperature,
        }
        return GeminiSpeechToText(
            model=model_name if model_name else 'gemini-pro',
            api_key=model_credential.get('api_key'),
            **additional_params,
        )

    def check_auth(self) -> bool:
        try:
            client = ChatGoogleGenerativeAI(
                model=self.model,
                google_api_key=self.api_key
            )
            client.invoke('你好')  # "Hello": a minimal round trip to validate the credential
            print('Authentication successful')
            return True
        except Exception as e:
            print(f'Failed authentication attempt: {repr(e)}')
            return False

    async def speak_and_transcribe(
            self,
            speech_file_path: str
    ) -> Optional[str]:
        """
        Asynchronously transcribe the contents of a given audio file.

        This method reads the file, sends it through Gemini, and returns the
        transcription result.

        Args:
            speech_file_path (str): Path to the input audio file.

        Returns:
            An optional string containing the transcribed text; None upon failure.

        Raises:
            FileNotFoundError: The audio file at the given path does not exist.
            PermissionError: Access was denied while reading the audio file.
            RuntimeError: Error communicating with the server while attempting transcription.

        Example usage:
            await stt.speak_and_transcribe('path/to/input.mp3')
        """
        try:
            with open(speech_file_path, mode='rb') as f_read:
                audio_data = f_read.read()
            client = ChatGoogleGenerativeAI(
                model=self.model,
                google_api_key=self.api_key
            )
            message = convert_bytes_to_content(audio_data)
            resp = await client.ainvoke([message])
            print('[SUCCESSFUL TRANSCRIPTION]')
            return resp.content.strip('\n\r ')  # strip surrounding line breaks
        except FileNotFoundError:
            print('FileNotFoundError: The sound file was not found.')
            raise
        except PermissionError:
            print('PermissionError: Permission was denied while reading the sound file.')
            raise
        except RuntimeError as rte:
            print(f"The system couldn't complete the request because of an internal processing issue:\n{rte}")
            raise
        except ValueError as ve:
            print(f'Unknown ValueError occurred:\n{ve}')
            raise
```
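For completeness, a minimal driver for the async path might look like this. It is a sketch: the API key and file path are placeholders, and `model_type='STT'` simply mirrors the `ModelTypeConst.STT` registration shown later in this PR:

```python
import asyncio


async def main():
    stt = GeminiSpeechToText.new_instance(
        model_type='STT',
        model_name='gemini-1.5-flash',
        model_credential={'api_key': 'YOUR_API_KEY'},  # placeholder credential
    )
    if stt.check_auth():
        transcript = await stt.speak_and_transcribe('path/to/input.mp3')
        print(transcript)


asyncio.run(main())
```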
```python
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        pass
```
Here are my observations:
The code looks mostly consistent with common practices:

- It uses Python 3 syntax (with `# coding=utf-8`), which is good.
- The class and methods follow PEP 8 guidelines.

However, there are some points to consider for improvement:

- **Comments:** Provide more explanation of complex operations like type hinting and method logic.
- **Error Handling:** Consider using `logging` rather than raising exceptions in every place where error handling might be appropriate.
- **Code Duplication:** The `encryption_dict` method can be simplified since it simply wraps another function (`super().encryption()`); see the sketch below.

Overall, the code is well-written but could benefit from better documentation and possibly refactoring to improve clarity and maintainability.
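To illustrate the de-duplication point: if several credential classes repeat that one-line wrapper, a default implementation on the shared base class would remove the copies. A minimal sketch, assuming `api_key` is the only sensitive field and using a stand-in for the real `encryption` helper (the actual base-class interface in MaxKB may differ):

```python
from typing import Dict


class BaseModelCredential:
    """Sketch of the shared base; only the parts relevant here are shown."""

    def encryption(self, value: str) -> str:
        # Stand-in for the real helper: mask all but the last four characters.
        return '******' + value[-4:] if value else ''

    def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
        # Default implementation, so subclasses such as the Gemini credential
        # no longer need their own one-line override.
        return {**model, 'api_key': self.encryption(str(model.get('api_key', '')))}
```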
```python
    .append_default_model_info(model_stt_info_list[0])
    .build()
)
```
There are a few minor improvements and optimizations that can be made to enhance the code:

- **Remove Trailing Commas:** You have trailing commas after some dictionary entries (`gemini_llm_model_credential`, `gemini_image_model_credential`). These should be removed for consistency.

- **Consistent Spacing:** Ensure consistent spacing around operators (e.g., `=`) and after commas.

- **Import Order:** It might be helpful to sort the imports alphabetically for better readability and maintainability.

- **Variable Naming:** Use meaningful variable names like `stt_models` instead of `model_stt_info_list`.
Here's the revised version with these changes:
```diff
@@ -8,7 +8,7 @@
 from setting.models_provider.impl.gemini_model_provider.model.stt import GeminiSpeechToText
 from smartdoc.conf import PROJECT_DIR

 gemini_stt_model_credential = GeminiSTTModelCredential()

 model_info_list = [
     ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新',
               ModelTypeConst.GENERAL_MODEL,
               None,
@@ -42,13 +42,13 @@
               GeminiImage),
 ]

-model_stt_info_list = [
+stt_models = [
     ModelInfo('gemini-1.5-flash', '最新的Gemini 1.5 Flash模型,随Google更新而更新',
               ModelTypeConst.STT,
               gemini_stt_model_credential,
               GeminiSpeechToText),
 ]

 model_info_manage = (
     ModelInfoManage.builder()
     .append_model_info_list(model_info_list)
-    .append_model_info_list(model_stt_info_list)
-    .append_default_model_info(model_stt_info_list[0])
+    .append_model_info_list(stt_models)
+    .append_default_model_info(stt_models[0])
     .build()
 )
```

These changes make the code cleaner and more readable.
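Applied, the registration section of the provider module would read as follows (all names and context lines are taken from the diff above):

```python
gemini_stt_model_credential = GeminiSTTModelCredential()

stt_models = [
    ModelInfo('gemini-1.5-flash', '最新的Gemini 1.5 Flash模型,随Google更新而更新',
              ModelTypeConst.STT,
              gemini_stt_model_credential,
              GeminiSpeechToText),
]

model_info_manage = (
    ModelInfoManage.builder()
    .append_model_info_list(model_info_list)
    .append_model_info_list(stt_models)
    .append_default_model_info(stt_models[0])
    .build()
)
```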