
Conversation

@shaohuzhang1
Contributor

feat: Support gemini stt model

@f2c-ci-robot

f2c-ci-robot bot commented Dec 19, 2024

Adding the "do-not-merge/release-note-label-needed" label because no release-note block was detected, please follow our release note process to remove it.


Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository.

@f2c-ci-robot

f2c-ci-robot bot commented Dec 19, 2024

[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by:

The full list of commands accepted by this bot can be found here.

Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

{"type": "media", 'mime_type': 'audio/mp3', "data": audio_data}
])
res = client.invoke([msg])
return res.content
Contributor Author

The provided code looks generally well-structured for using Google Generative AI (Gemini) for speech-to-text. However, there are a few points to consider:

  1. Tokenization: The custom_get_token_ids function uses a tokenizer from a module named TokenizerManage, but it's unclear where this module or its dependencies are defined or implemented. Ensure that the necessary functions are correctly imported and accessible.

  2. Static Method Documentation:

     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):

     It would be helpful to include type information for `model_type` and `model_name` within the static method signature as well:

     ```python
     @staticmethod
     def new_instance(
         model_type: str,
         model_name: str,
         model_credential: Dict[str, object],
         **model_kwargs
     ):
     ```
  3. Logging:
    The current checks use print(response_list), which should likely be replaced with logging statements for better readability and debugging support (a minimal sketch follows this list).

  4. Optional Parameters:
    The default values of max_tokens and temperature are set in the constructor; consider defining the defaults explicitly rather than relying on keyword arguments directly.

  5. Method Comments: While comments do exist in the code, more detailed descriptions can help other developers understand each part of the flow better.

  6. Async/Await Considerations:
    If you're planning to integrate asynchronous operations, ensure all such methods are declared appropriately (async def) and that async tasks are handled properly (e.g., using await).

  7. Dependencies:
    Make sure that all required libraries (langchain_core, typing, google-generativeai, etc.) are installed and available in your environment.
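
A minimal sketch of the logging suggestion in item 3; the helper name and the client object are illustrative, not part of the PR:

```python
import logging

# Illustrative only: the project may already configure logging centrally.
logger = logging.getLogger(__name__)


def check_auth_logged(client) -> bool:
    """Example of replacing print() calls with logging statements."""
    try:
        client.invoke('hello')  # trivial prompt to verify the credential
        logger.info('Gemini authentication successful')
        return True
    except Exception as exc:
        logger.warning('Gemini authentication failed: %r', exc)
        return False
```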

Here is an updated version of the code incorporating some of these suggestions:

```python
from typing import Dict, Optional

from langchain_core.messages import HumanMessage
# Note: ChatGoogleGenerativeAI is provided by the langchain-google-genai package,
# not by google.generativeai directly.
from langchain_google_genai import ChatGoogleGenerativeAI

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_stt import BaseSpeechToText

# ModelConstants (used below for default max_tokens / temperature) is assumed to be defined elsewhere.


def custom_get_token_ids(text: str) -> list[int]:
    """Converts text into token IDs using appropriate tokenization."""
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)

def convert_bytes_to_content(audio_data: bytes) -> HumanMessage:
    """
    Converts binary data representing a media file into a LangChain message.

    :param audio_data: Binary content of the audio file.
    :return: A HumanMessage containing the media details.
    """
    return HumanMessage(content=[
        {'type': 'text', 'text': '把音频转成文字'},  # prompt text: "transcribe the audio to text"
        {'type': 'media', 'mime_type': 'audio/mp3', 'data': audio_data}
    ])

class GeminiSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """A class implementing speech to text functionality using Google Generative AI."""

    api_key: str
    model: str
    max_tokens: int = 160  # Default value if not specified
    temperature: float = 0.8  # Default value if not specified

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key', '')
        # Keep the model name on the instance (in case the base class does not set it).
        self.model = kwargs.get('model', '')
        self.max_tokens = kwargs.get('max_tokens', self.max_tokens)
        self.temperature = kwargs.get('temperature', self.temperature)

    @staticmethod
    def new_instance(
            model_type: str,
            model_name: str,
            model_credential: Dict[str, object],
            **model_kwargs
    ) -> 'GeminiSpeechToText':
        additional_params = {
            'max_tokens': model_kwargs.get('max_tokens') or ModelConstants.DEFAULT_GEMINI_MAX_TOKENS,
            'temperature': model_kwargs.get('temperature') or ModelConstants.DEFAULT_GEMINI_TEMPERATURE
        }
        return GeminiSpeechToText(
            model=model_name if model_name else 'gemini-pro',
            api_key=model_credential.get('api_key'),
            **additional_params,
        )

    def check_auth(self) -> bool:
        try:
            client = ChatGoogleGenerativeAI(
                model=self.model,
                google_api_key=self.api_key
            )
            # Send a trivial prompt to verify that the credential is accepted.
            client.invoke('你好')
            print('Authentication successful')
            return True
        except Exception as e:
            print(f'Failed authentication attempt: {repr(e)}')
            return False

    async def speak_and_transcribe(
            self,
            speech_file_path: str
    ) -> Optional[str]:
        """
        Asynchronously transcribes the contents of a given audio file.

        The method reads the audio file, sends it to the Gemini model,
        and returns the transcription result.

        Args:
            speech_file_path (str): Path to the input audio file.

        Returns:
            An optional string containing the transcribed text; None upon failure.

        Raises:
            FileNotFoundError: The audio file at the given path does not exist.
            PermissionError: Access was denied while reading the audio file.
            RuntimeError: Error communicating with the server while attempting transcription.

        Example usage:

            transcript = await stt.speak_and_transcribe(speech_file_path='path/to/input.mp3')
        """
        try:
            with open(speech_file_path, mode='rb') as f_read:
                audio_data = f_read.read()

            client = ChatGoogleGenerativeAI(
                model=self.model,
                google_api_key=self.api_key
            )

            # Wrap the raw audio bytes in a LangChain message and invoke the model asynchronously.
            msg = convert_bytes_to_content(audio_data)
            resp = await client.ainvoke([msg])
            transcript = resp.content

            print('[SUCCESSFUL TRANSCRIPTION]')
            return transcript.strip('\n\r ')  # Remove surrounding line breaks and spaces
            
        except FileNotFoundError:
            print('FileNotFoundError: The sound file was not found.')
            raise
        except PermissionError:
            print('PermissionError: Permission was denied while reading the sound file.')
            raise
        except RuntimeError as rte:
            print(f"The system couldn't complete the request because of an internal processing issue:\n{rte}")
            raise
        except ValueError as ve:
            print(f'Unknown ValueError occurred:\n{ve}')
            raise
```
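A hypothetical usage sketch for the class above; the API key, file path, and parameter values are placeholders, not part of the PR:

```python
import asyncio

stt = GeminiSpeechToText.new_instance(
    model_type='STT',
    model_name='gemini-1.5-flash',
    model_credential={'api_key': 'YOUR_API_KEY'},
    max_tokens=800,       # passed explicitly so the ModelConstants fallback is not exercised
    temperature=0.7,
)

if stt.check_auth():
    transcript = asyncio.run(stt.speak_and_transcribe('path/to/input.mp3'))
    print(transcript)
```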

return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

def get_model_params_setting_form(self, model_name):
    pass
Contributor Author

Here are my observations:

The code looks mostly consistent with common practices:

  • It uses modern Python 3 syntax and declares the source encoding (# coding=utf-8), which is good.
  • The class and methods follow PEP 8 guidelines.

However, there are some points to consider for improvement:

  1. Comments: Provide more explanation of complex operations like type hinting and method logic.
  2. Error Handling: Consider using logging for error reporting rather than raising exceptions in every place where error handling might be appropriate.
  3. Code Duplication: The encryption_dict method is a thin wrapper around super().encryption() and can stay that simple (see the sketch below for how it fits into the credential class).

Overall, the code is well-written but could benefit from better documentation and possibly refactoring to improve clarity and maintainability.
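
A partial sketch of how the two methods shown in the diff context above might sit inside the credential class; the base-class name, import path, and omitted methods (e.g. credential validation) are assumptions, not code from this PR:

```python
from typing import Dict

# Assumed import path for the credential base class.
from setting.models_provider.base_model_provider import BaseModelCredential


class GeminiSTTModelCredential(BaseModelCredential):
    def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]:
        # Mask only the api_key; every other field passes through unchanged.
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        # No extra parameter form is needed for the STT model.
        pass
```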

.append_default_model_info(model_stt_info_list[0])
.build()
)

Contributor Author

There are a few minor improvements and optimizations that can be made to enhance the code:

  1. Remove Trailing Commas: You have trailing commas after some dictionary entries (gemini_llm_model_credential, gemini_image_model_credential). These should be removed for consistency.

  2. Consistent Spacing: Ensure consistent spacing around operators (e.g., =) and after commas.

  3. Import Order: It might be helpful to sort the imports alphabetically for better readability and maintainability.

  4. Variable Naming: Use meaningful variable names like stt_models instead of model_stt_info_list.

Here's the revised version with these changes:

```diff
@@ -8,7 +8,7 @@
 from setting.models_provider.impl.gemini_model_provider.model.stt import GeminiSpeechToText
 from smartdoc.conf import PROJECT_DIR

 gemini_stt_model_credential = GeminiSTTModelCredential()

 model_info_list = [
     ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新',
              ModelTypeConst.GENERAL_MODEL,
              None,
@@ -42,13 +41,18 @@
               GeminiImage),
 ]

+stt_models = [
+    ModelInfo('gemini-1.5-flash', '最新的Gemini 1.5 Flash模型,随Google更新而更新',
+              ModelTypeConst.STT,
+              gemini_stt_model_credential,
+              GeminiSpeechToText),
+]

 model_info_manage = (
     ModelInfoManage.builder()
         .append_model_info_list(model_info_list)
         .append_model_info_list(stt_models)
         .append_default_model_info(default_model)
         .build()
 )
```

These changes make the code cleaner and more readable.

@liuruibin liuruibin merged commit 24bb7d5 into main Dec 19, 2024
4 checks passed
@liuruibin liuruibin deleted the pr@main@refactor_stt_gemini branch December 19, 2024 10:26