-
Notifications
You must be signed in to change notification settings - Fork 2.6k
feat: Vllm whisper model #3901
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Vllm whisper model #3901
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| # coding=utf-8 | ||
| import traceback | ||
| from typing import Dict | ||
|
|
||
| from django.utils.translation import gettext_lazy as _, gettext | ||
| from langchain_core.messages import HumanMessage | ||
|
|
||
| from common import forms | ||
| from common.exception.app_exception import AppApiException | ||
| from common.forms import BaseForm, TooltipLabel | ||
| from models_provider.base_model_provider import BaseModelCredential, ValidCode | ||
|
|
||
|
|
||
class VLLMWhisperModelParams(BaseForm):
    """Runtime parameter form shown for a vLLM-hosted Whisper STT model."""
    # Transcription language code forwarded to the Whisper endpoint
    # (e.g. 'zh', 'en'); pre-filled with 'zh' when the user leaves it unchanged.
    Language = forms.TextInputField(
        TooltipLabel(_('Language'),
                     _("If not passed, the default value is 'zh'")),
        required=True,
        default_value='zh',
    )
|
|
||
|
|
||
class VLLMWhisperModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for vLLM-hosted Whisper STT models.

    Collects the vLLM server's API URL and API key, and validates that the
    server is reachable, that the requested model type/name exist, and that
    a model instance can be constructed with the given parameters.
    """
    api_url = forms.TextInputField('API URL', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self,
                 model_type: str,
                 model_name,
                 model_credential: Dict[str, object],
                 model_params,
                 provider,
                 raise_exception=False):
        """Validate the credential against the live vLLM server.

        Returns True when everything checks out. On failure: raises
        AppApiException when ``raise_exception`` is True, otherwise
        returns False (the original implementation ignored the flag
        and always raised).
        """
        try:
            model_type_list = provider.get_model_type_list()
            # Reject model types this provider does not advertise.
            if not any(mt.get('value') == model_type for mt in model_type_list):
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext('{model_type} Model type is not supported').format(
                                          model_type=model_type))
            try:
                model_list = provider.get_base_model_list(model_credential.get('api_url'),
                                                          model_credential.get('api_key'))
            except Exception as e:
                # Chain the cause so the underlying connection error is preserved.
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext('API domain name is invalid')) from e
            exist = provider.get_model_info_by_name(model_list, model_name)
            if len(exist) == 0:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext('The model does not exist, please download the model first'))
            # Instantiating the model surfaces bad credentials/params early;
            # the instance itself is not needed here.
            provider.get_model(model_type, model_name, model_credential, **model_params)
        except AppApiException:
            if raise_exception:
                raise
            return False
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        """Return *model_info* with the API key encrypted for persistence."""
        return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}

    def build_model(self, model_info: Dict[str, object]):
        """Bind required credential fields onto this instance.

        Raises AppApiException(500) if 'api_key' or 'model' is missing.
        """
        for key in ['api_key', 'model']:
            if key not in model_info:
                raise AppApiException(500, gettext('{key} is required').format(key=key))
        self.api_key = model_info.get('api_key')
        return self

    def get_model_params_setting_form(self, model_name):
        """Return the runtime parameter form for this model family."""
        return VLLMWhisperModelParams()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| import base64 | ||
| import os | ||
| import traceback | ||
| from typing import Dict | ||
|
|
||
| from openai import OpenAI | ||
|
|
||
| from common.utils.logger import maxkb_logger | ||
| from models_provider.base_model_provider import MaxKBBaseModel | ||
| from models_provider.impl.base_stt import BaseSpeechToText | ||
|
|
||
|
|
||
|
|
||
class VllmWhisperSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text client for a Whisper model served through vLLM's
    OpenAI-compatible API (``{api_url}/v1``)."""
    api_key: str
    api_url: str
    model: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.model = kwargs.get('model')
        # Default to an empty dict so speech_to_text can safely call
        # self.params.get(...) even when no params were supplied.
        self.params = kwargs.get('params') or {}
        self.api_url = kwargs.get('api_url')

    @staticmethod
    def is_cache_model():
        # A fresh client is built per request; instances are not cached.
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Factory used by the provider registry to build a model instance."""
        return VllmWhisperSpeechToText(
            model=model_name,
            api_key=model_credential.get('api_key'),
            api_url=model_credential.get('api_url'),
            params=model_kwargs,
            **model_kwargs
        )

    def check_auth(self):
        """Verify credentials by transcribing a bundled sample clip.

        Any failure now propagates as an exception (see speech_to_text),
        so callers can actually detect invalid credentials.
        """
        cwd = os.path.dirname(os.path.abspath(__file__))
        with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as audio_file:
            self.speech_to_text(audio_file)

    def speech_to_text(self, audio_file):
        """Transcribe *audio_file* and return the recognized text.

        Raises the underlying exception after logging it; the previous
        implementation swallowed every error and silently returned None,
        which made check_auth always succeed.
        """
        base_url = f"{self.api_url}/v1"
        try:
            client = OpenAI(
                api_key=self.api_key,
                base_url=base_url
            )
            result = client.audio.transcriptions.create(
                file=audio_file,
                model=self.model,
                language=self.params.get('Language'),
                response_format="json"
            )
            return result.text
        except Exception as err:
            maxkb_logger.error(f":Error: {str(err)}: {traceback.format_exc()}")
            # Propagate so check_auth and callers see the failure instead
            # of receiving a silent None.
            raise
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,20 +10,27 @@ | |
| from models_provider.impl.vllm_model_provider.credential.embedding import VllmEmbeddingCredential | ||
| from models_provider.impl.vllm_model_provider.credential.image import VllmImageModelCredential | ||
| from models_provider.impl.vllm_model_provider.credential.llm import VLLMModelCredential | ||
| from models_provider.impl.vllm_model_provider.credential.whisper_stt import VLLMWhisperModelCredential | ||
| from models_provider.impl.vllm_model_provider.model.embedding import VllmEmbeddingModel | ||
| from models_provider.impl.vllm_model_provider.model.image import VllmImage | ||
| from models_provider.impl.vllm_model_provider.model.llm import VllmChatModel | ||
| from maxkb.conf import PROJECT_DIR | ||
| from django.utils.translation import gettext as _ | ||
|
|
||
| from models_provider.impl.vllm_model_provider.model.whisper_sst import VllmWhisperSpeechToText | ||
|
|
||
| v_llm_model_credential = VLLMModelCredential() | ||
| image_model_credential = VllmImageModelCredential() | ||
| embedding_model_credential = VllmEmbeddingCredential() | ||
| whisper_model_credential = VLLMWhisperModelCredential() | ||
|
|
||
| model_info_list = [ | ||
| ModelInfo('facebook/opt-125m', _('Facebook’s 125M parameter model'), ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel), | ||
| ModelInfo('BAAI/Aquila-7B', _('BAAI’s 7B parameter model'), ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel), | ||
| ModelInfo('BAAI/AquilaChat-7B', _('BAAI’s 13B parameter mode'), ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel), | ||
| ModelInfo('facebook/opt-125m', _('Facebook’s 125M parameter model'), ModelTypeConst.LLM, v_llm_model_credential, | ||
| VllmChatModel), | ||
| ModelInfo('BAAI/Aquila-7B', _('BAAI’s 7B parameter model'), ModelTypeConst.LLM, v_llm_model_credential, | ||
| VllmChatModel), | ||
| ModelInfo('BAAI/AquilaChat-7B', _('BAAI’s 13B parameter mode'), ModelTypeConst.LLM, v_llm_model_credential, | ||
| VllmChatModel), | ||
|
|
||
| ] | ||
|
|
||
|
|
@@ -32,7 +39,15 @@ | |
| ] | ||
|
|
||
| embedding_model_info_list = [ | ||
| ModelInfo('HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5', '', ModelTypeConst.EMBEDDING, embedding_model_credential, VllmEmbeddingModel), | ||
| ModelInfo('HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5', '', ModelTypeConst.EMBEDDING, | ||
| embedding_model_credential, VllmEmbeddingModel), | ||
| ] | ||
|
|
||
| whisper_model_info_list = [ | ||
| ModelInfo('whisper-tiny', '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText), | ||
| ModelInfo('whisper-large-v3-turbo', '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText), | ||
| ModelInfo('whisper-small', '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText), | ||
| ModelInfo('whisper-large-v3', '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText), | ||
| ] | ||
|
|
||
| model_info_manage = ( | ||
|
|
@@ -45,6 +60,8 @@ | |
| .append_default_model_info(image_model_info_list[0]) | ||
| .append_model_info_list(embedding_model_info_list) | ||
| .append_default_model_info(embedding_model_info_list[0]) | ||
| .append_model_info_list(whisper_model_info_list) | ||
| .append_default_model_info(whisper_model_info_list[0]) | ||
| .build() | ||
| ) | ||
|
|
||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This code looks mostly correct and should work without significant issues. However, there are a few areas that could be improved:
Here's an updated version with some optimizations: from models_provider.impl.vllm_model_provider.credential.embedding import VllmEmbeddingCredential
from models_provider.impl.vllm_model_provider.credential.image import VllmImageModelCredential
from models_provider.impl.vllm_model_provider.credential.llm import VLLMModelCredential
from models_provider.impl.vllm_model_provider.credential.whisper_stt import VLLMWhisperModelCredential
from models_provider.impl.vllm_model_provider.model.embedding import VllmEmbeddingModel
from models_provider.impl.vllm_model_provider.model.image import VllmImage
from models_provider.impl.vllm_model_provider.model.llm import VllmChatModel
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _
v_llm_model_credential = VLLMModelCredential()
image_model_credential = VllmImageModelCredential()
embedding_model_credential = VllmEmbeddingCredential()
whisper_model_credential = VLLMWhisperModelCredential()
model_info_lists = [
(VLLMChatModel, v_llm_model_credential),
(VllmEmbeddingModel, embedding_model_credential),
(VllmWhisperSpeechToText, whisper_model_credential)
]
all_models = (
image_model_info_list +
embedding_model_info_list +
whisper_model_info_list
)
# Assuming append_default_model_info handles adding the first element twice if needed
config_management.append_model_info_list(all_models).build()

Changes Made:
These changes aim to make the code cleaner and potentially more efficient by avoiding redundancy in the model info lists. |
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No irregularities found. The code looks to be correctly structured for a Django-based form handling with model credential validation. Here are some optimizations you might consider:
Use `gettext` directly: Since both `_('Language')` and `gettext('{model_type} Model type is not supported').format(model_type=model_type)` use `gettext`, they can be consolidated into a single call.

Remove unnecessary imports: `langchain_core.messages.HumanMessage` is imported but never referenced within this class, so it's safe to remove from the imports list.

Encapsulate logic: You could encapsulate some of the exception handling and message formatting in helper functions rather than repeating them across lines.
Consider using context managers: If you anticipate making multiple network requests or database interactions, using async contexts (Python 3.7+) would help manage operations more cleanly.
Here's an updated version with these considerations:
This version introduces helpers like `_create_error_msg` for better readability and consolidates duplicated message creation patterns.