Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# coding=utf-8
from typing import Dict

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class GeminiSTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Gemini speech-to-text models."""

    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate a credential dict against the provider.

        Three checks run in order: ``model_type`` must match the ``value`` of an
        entry in ``provider.get_model_type_list()``, ``api_key`` must be present
        in ``model_credential``, and the model returned by
        ``provider.get_model(...)`` must pass ``check_auth()``.

        :param model_type: model type value to look up in the provider's list.
        :param model_name: model name forwarded to ``provider.get_model``.
        :param model_credential: credential mapping; must contain ``api_key``.
        :param provider: object exposing ``get_model_type_list`` and ``get_model``.
        :param raise_exception: when true, a missing key or a failed auth check
            raises instead of returning ``False``.
        :return: ``True`` when every check passes, ``False`` when a check fails
            and ``raise_exception`` is false.
        :raises AppApiException: always for an unsupported model type; for a
            missing key or failed auth check when ``raise_exception`` is true;
            and whenever the auth check itself raises ``AppApiException``.
        """
        matching_types = [mt for mt in provider.get_model_type_list() if mt.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        required_keys = ('api_key',)
        for key in required_keys:
            if key in model_credential:
                continue
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')

        try:
            provider.get_model(model_type, model_name, model_credential).check_auth()
        except AppApiException:
            # Application-level errors are passed through untouched.
            raise
        except Exception as e:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` has been passed through ``encryption``.

        A missing ``api_key`` is treated as the empty string.
        """
        encrypted = dict(model)
        encrypted['api_key'] = super().encryption(model.get('api_key', ''))
        return encrypted

    def get_model_params_setting_form(self, model_name):
        """Return ``None``: no parameter-setting form is defined here."""
        return None
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here are my observations:

The code looks mostly consistent with common practices:

  • It uses Python 3 syntax (with # coding=utf-8), which is good.
  • The class and methods follow PEP 8 guidelines.

However, there are some points to consider for improvement:

  1. Comments: Provide more explanation of complex operations like type hinting and method logic.
  2. Error Handling: Consider using logging instead of raising exceptions in all places where error handling might be appropriate.
  3. Code Duplication: The encryption_dict method can be simplified since it simply wraps another function (super().encryption()).

Overall, the code is well-written but could benefit from better documentation and possibly refactoring to improve clarity and maintainability.

Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
ModelInfoManage
from setting.models_provider.impl.gemini_model_provider.credential.image import GeminiImageModelCredential
from setting.models_provider.impl.gemini_model_provider.credential.llm import GeminiLLMModelCredential
from setting.models_provider.impl.gemini_model_provider.credential.stt import GeminiSTTModelCredential
from setting.models_provider.impl.gemini_model_provider.model.image import GeminiImage
from setting.models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel
from setting.models_provider.impl.gemini_model_provider.model.stt import GeminiSpeechToText
from smartdoc.conf import PROJECT_DIR

# Module-level credential instances, one per Gemini credential class imported above.
gemini_llm_model_credential = GeminiLLMModelCredential()
gemini_image_model_credential = GeminiImageModelCredential()
gemini_stt_model_credential = GeminiSTTModelCredential()

model_info_list = [
ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新',
Expand All @@ -42,14 +45,25 @@
GeminiImage),
]


# Speech-to-text models; both entries share the STT credential instance and
# the GeminiSpeechToText implementation class.
model_stt_info_list = [
    ModelInfo('gemini-1.5-flash', '最新的Gemini 1.5 Flash模型,随Google更新而更新',
              ModelTypeConst.STT,
              gemini_stt_model_credential,
              GeminiSpeechToText),
    # Description corrected: this entry is the 1.5 Pro model, not Flash.
    ModelInfo('gemini-1.5-pro', '最新的Gemini 1.5 Pro模型,随Google更新而更新',
              ModelTypeConst.STT,
              gemini_stt_model_credential,
              GeminiSpeechToText),
]

# Build the provider's model registry: first register each model list, then
# register the first entry of each list through append_default_model_info,
# and finally call build().
model_info_manage = (
ModelInfoManage.builder()
.append_model_info_list(model_info_list)
.append_model_info_list(model_image_info_list)
.append_model_info_list(model_stt_info_list)
.append_default_model_info(model_info_list[0])
.append_default_model_info(model_image_info_list[0])
.append_default_model_info(model_stt_info_list[0])
.build()
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a few minor improvements and optimizations that can be made to enhance the code:

  1. Remove Trailing Commas: You have trailing commas after some dictionary entries (gemini_llm_model_credential, gemini_image_model_credential). These should be removed for consistency.

  2. Consistent Spacing: Ensure consistent spacing around operators (e.g., =) and after commas.

  3. Import Order: It might be helpful to sort the imports alphabetically for better readability and maintainability.

  4. Variable Naming: Use meaningful variable names like stt_models instead of model_stt_info_list.

Here's the revised version with these changes:

@@ -8,7 +8,6 @@
 from setting.models_provider.impl.gemini_model_provider.model.stt import GeminiSpeechToText
 from smartdoc.conf import PROJECT_DIR

-gemini_stt_model_credential = GeminiSTTModelCredential()

-model_info_list = [
+stt_models = [
     ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新',
              ModelTypeConst.GENERAL_MODEL,
              None,
@@ -42,13 +41,18 @@
               GeminiImage),
 ]

+
+models_stt_list = [
+    ModelInfo('gemini-1.5-flash', '最新的Gemini 1.5 Flash模型,随Google更新而更新',
+              ModelTypeConst.STT,
+              gemini_stt_model_credential,
+              GeminiSpeechToText),
+]

+
 model_info_manage = (
     ModelInfoManage.builder()
         .append_model_info_list(model_info_list)
         .append_model_info_list(stt_models)
         .append_default_model_info(default_model)
         .build()
 )

These changes make the code cleaner and more readable.

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import asyncio
import io
from typing import Dict

from langchain_core.messages import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from openai import OpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_stt import BaseSpeechToText
import google.generativeai as genai


def custom_get_token_ids(text: str):
    """Encode ``text`` with the tokenizer from ``TokenizerManage`` and return the result."""
    return TokenizerManage.get_tokenizer().encode(text)


class GeminiSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text model that sends audio to a Gemini chat model via ``ChatGoogleGenerativeAI``."""

    api_key: str
    model: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create an instance from a model name and a credential mapping.

        :param model_type: accepted for interface compatibility; not used here.
        :param model_name: passed to the constructor as ``model``.
        :param model_credential: mapping from which ``api_key`` is read.
        :param model_kwargs: ``max_tokens`` and ``temperature`` are forwarded to
            the constructor when present and not ``None``; other keys are ignored.
        :return: a new ``GeminiSpeechToText``.
        """
        optional_params = {
            name: model_kwargs[name]
            for name in ('max_tokens', 'temperature')
            if model_kwargs.get(name) is not None
        }
        return GeminiSpeechToText(
            model=model_name,
            api_key=model_credential.get('api_key'),
            **optional_params,
        )

    def _build_client(self):
        """Return a ``ChatGoogleGenerativeAI`` client for this model and API key."""
        return ChatGoogleGenerativeAI(
            model=self.model,
            google_api_key=self.api_key
        )

    def check_auth(self):
        """Send a short text prompt to the model.

        The reply is discarded; any exception raised by the call propagates to
        the caller.
        """
        self._build_client().invoke('你好')

    def speech_to_text(self, audio_file):
        """Transcribe ``audio_file`` and return the content of the model's reply.

        :param audio_file: object with a ``read()`` method returning the audio data.
        :return: ``content`` of the response to a message holding a text
            instruction plus the audio as a ``media`` part.
        """
        client = self._build_client()
        audio_data = audio_file.read()
        # NOTE(review): the MIME type is fixed to audio/mp3 whatever the actual
        # format of audio_file is -- confirm callers always supply MP3 data.
        msg = HumanMessage(content=[
            {'type': 'text', 'text': '把音频转成文字'},
            {"type": "media", 'mime_type': 'audio/mp3', "data": audio_data}
        ])
        res = client.invoke([msg])
        return res.content
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The provided code looks generally well-structured for speech-to-text using Google's Gemini models through LangChain's ChatGoogleGenerativeAI (the OpenAI import is unused). However, there are a few points to consider:

  1. Tokenization: The custom_get_token_ids function uses a tokenizer from a module named TokenizerManage, but it's unclear where this module or its dependencies are defined or implemented. Ensure that the necessary functions are correctly imported and accessible.

  2. Static Method Documentation:

     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
It would be helpful to add type annotations for `model_type` and `model_name` in the static method signature (`model_credential` is already annotated):
```python
    @staticmethod
    def new_instance(
        model_type: str,
        model_name: str,
        model_credential: Dict[str, object],
        **model_kwargs
    ):
  1. Logging:
    The check_auth method contains a commented-out print(response_list); if diagnostic output is needed, it should be replaced with logging statements for better readability and debugging support.

  2. Optional Parameters:
    The default values of max_tokens and temperature are set in the constructor, but they might want to define default values explicitly instead of relying on keyword arguments directly.

  3. Method Comments: While comments do exist in the code, more detailed descriptions can help other developers understand each part of the flow better.

  4. Async/Await Considerations:
    If you're planning to integrate asynchronous operations, ensure all methods supporting such operations are decorated appropriately (async def) along with proper handling of async tasks (e.g., using await).

  5. Dependencies:
    Make sure that all required libraries (langchain_core, typing, google-generativeai, etc.) are installed and available in your environment.

Here is an updated version of the code incorporating some of these suggestions:

@@ -0,0 +1,79 @@
+import asyncio
+import io
+from typing import Dict, Optional

+import google.generativeai as genai

+from common.config.tokenizer_manage_config import TokenizerManage
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_stt import BaseSpeechToText
+from langchain_core.messages import HumanMessage


def custom_get_token_ids(text: str) -> list[int]:
    """Converts text into token IDs using appropriate tokenization."""
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)

def convert_bytes_to_content(audio_data: bytes) -> HumanMessage:
    """
    Converts binary data representing a media file into a LangChain message.

    :param audio_data: Binary content of the audio file.
    :return: A HumanMessage containing the media details.
    """
    return HumanMessage(content=[
        {'type': 'text', 'text': '把音频转成文字'},
        {"type": "media", 'mime_type': 'audio/mp3', "data": audio_data}
    ])

class GeminiSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """A class implementing speech to text functionality using Google Generative AI."""

    api_key: str
    model: str
    max_tokens: int = 160  # Default value if not specified
    temperature: float = 0.8  # Default value if not specified

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key', '')
        self.max_tokens = kwargs.get('max_tokens', self.max_tokens)
        self.temperature = kwargs.get('temperature', self.temperature)

    @staticmethod
    def new_instance(
            model_type: str,
            model_name: str,
            model_credential: Dict[str, object],
            **model_kwargs
    ) -> 'GeminiSpeechToText':
        additional_params = {
            'max_tokens': model_kwargs.get('max_tokens') or ModelConstants.DEFAULT_GEMINI_MAX_TOKENS,
            'temperature': model_kwargs.get('temperature') or ModelConstants.DEFAULT_GEMINI_TEMPERATURE
        }
        return GeminiSpeechToText(
            model=model_name if model_name else 'gemini-pro',
            api_key=model_credential.get('api_key'),
            **additional_params,
        )

    def check_auth(self) -> bool:
        try:
            client = genai.ChatGoogleGenerativeAI(
                model=self.model,
                google_api_key=self.api_key
            )
            response = client.chat("你好")
            # Logging success
            print(f"Authentication successful via {response.status}")
            return True
        except Exception as e:
            # Logging error
            print(f"Failed authentication attempt: {repr(e)}")
            return False

    async def speak_and_transcribe(
            self,
            speech_file_path: str
    ) -> Optional[str]:
        """
        Asynchronously transcribes the contents of a given .wav file.
        
        This method reads the WAV file, processes it through GGAI, 
        and captures the transcription result.

        Args:
          speech_file_path(str): Path to the input .wav audio file.

        Returns:
          An optional string containing the transcribed text; otherwise None upon failure.

        Raises:
          FileNotFoundError: The wav files at path does not exist
		  PermissionError: Access denied to read the wav file
		  RuntimeError: Error communicating with server while attempting transcription.

        Example usage:
          
          await tts.speak_and_transcribe(input_sound="path/to/input.wav")

        Output example:

          The output will show something like:

              [INFO] Authentication successful via ok
  
              [SUCCESSFUL TRANSCRIPTION]
               Hello world! How can I assist you today?
           """
  
        try:
            with open(speech_file_path, mode='rb') as f_read:
                audio_data = f_read.read()

            client = genai.ChatGoogleGenerativeAI(
                model=self.model,
                google_api_key=self.api_key
            )
            
            messages = [
                  {'role': 'system', 'content': ''}
                 ,{
                    "role": "user",
                    "content":convert_bytes_to_content(audio_data),
                }
            ]
            
            resp = await client.generate_async(messages=[messages[1]])
            transcript = ""
            for chunk in resp.streamed_outputs():
                for choice in chunk.choices:
                    if choice.message.role == "assistant":
                        transcript += choice.text[len(chunk.cumulative_end_index):]

            # Log transcript successfully captured
            print("[SUCCESSFUL TRANSCRIPTION]")
            return transcript.strip("\n\r ")  # Remove line breaks
            
       except FileNotFoundError:
            print("FileNotFoundError: The sound file was not found.")
            raise

        except PermissionError:
            print("PermissionError: There was permission denied reading the sound file.")
            raise
        
        except RuntimeError as rte:
            print(f"The system could't complete because of an internal processing issue:\n{rte}")
            raise       
            
        except ValueError as ve:
            print(f"Unknown Value Error occurred:\n{ve}")
            raise       
            

                
            
            
            

Loading