Skip to content

Commit 9927128

Browse files
committed
refactor: support vllm embeddings image
1 parent ff3f511 commit 9927128

File tree

4 files changed

+159
-0
lines changed

4 files changed

+159
-0
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: embedding.py
6+
@date:2024/7/12 16:45
7+
@desc:
8+
"""
9+
from typing import Dict
10+
11+
from common import forms
12+
from common.exception.app_exception import AppApiException
13+
from common.forms import BaseForm
14+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
15+
from django.utils.translation import gettext_lazy as _
16+
17+
18+
19+
class VllmEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form for embedding models served through a vLLM
    OpenAI-compatible endpoint (API base URL + API key)."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=True):
        """Validate the credential.

        Checks that ``model_type`` is supported by the provider, that the
        required credential keys are present, and that a test embedding call
        succeeds.  Returns True on success; returns False (or raises
        ``AppApiException`` when ``raise_exception`` is truthy) on failure.
        """
        model_type_list = provider.get_model_type_list()
        # An unsupported model type is a caller error, so it raises
        # unconditionally regardless of raise_exception.
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
                return False
        try:
            # Round-trip a short embedding request to prove the endpoint
            # and key actually work.
            model = provider.get_model(model_type, model_name, model_credential)
            model.embed_query('你好')
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      _('Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        # Mask the API key before the credential is persisted or echoed back
        # to the client.
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_base = forms.TextInputField('API Url', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# coding=utf-8
2+
import base64
3+
import os
4+
from typing import Dict
5+
6+
from langchain_core.messages import HumanMessage
7+
8+
from common import forms
9+
from common.exception.app_exception import AppApiException
10+
from common.forms import BaseForm, TooltipLabel
11+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
12+
from django.utils.translation import gettext_lazy as _
13+
14+
class VllmImageModelParams(BaseForm):
    """Tunable generation parameters exposed in the UI for vLLM image models."""

    # Sampling temperature: higher values make output more random,
    # lower values more deterministic.  Range 0.1–1.0, two decimals.
    temperature = forms.SliderField(TooltipLabel(_('Temperature'),
                                                 _('Higher values make the output more random, while lower values make it more focused and deterministic')),
                                    required=True, default_value=0.7,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    # Upper bound on the number of tokens the model may generate per reply.
    max_tokens = forms.SliderField(
        TooltipLabel(_('Output the maximum Tokens'),
                     _('Specify the maximum number of tokens that the model can generate')),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)
31+
32+
33+
34+
class VllmImageModelCredential(BaseForm, BaseModelCredential):
    """Credential form for vision (image) chat models served through a vLLM
    OpenAI-compatible endpoint (API base URL + API key)."""

    api_base = forms.TextInputField('API Url', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        """Validate the credential.

        Checks that ``model_type`` is supported, that required credential keys
        are present, and that a streamed test chat call succeeds.  Returns
        True on success; returns False (or raises ``AppApiException`` when
        ``raise_exception`` is truthy) on failure.
        """
        model_type_list = provider.get_model_type_list()
        # An unsupported model type is a caller error, so it raises
        # unconditionally regardless of raise_exception.
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))

        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
                return False
        try:
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            # Stream a trivial prompt and drain the iterator; only a clean,
            # exception-free stream matters here (debug print removed).
            for _chunk in model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]):
                pass
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      _('Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        # Mask the API key before the credential is persisted or echoed back
        # to the client.
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        # Same parameter form regardless of the concrete model name.
        return VllmImageModelParams()
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: embedding.py
6+
@date:2024/7/12 17:44
7+
@desc:
8+
"""
9+
from typing import Dict
10+
11+
from langchain_community.embeddings import OpenAIEmbeddings
12+
13+
from setting.models_provider.base_model_provider import MaxKBBaseModel
14+
15+
16+
class VllmEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings):
    """Embedding model backed by a vLLM OpenAI-compatible endpoint."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a :class:`VllmEmbeddingModel` from stored credential fields."""
        # Translate the stored credential dict into OpenAIEmbeddings
        # constructor keywords.
        init_kwargs = {
            'model': model_name,
            'openai_api_key': model_credential.get('api_key'),
            'openai_api_base': model_credential.get('api_base'),
        }
        return VllmEmbeddingModel(**init_kwargs)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from typing import Dict
2+
3+
from setting.models_provider.base_model_provider import MaxKBBaseModel
4+
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI
5+
6+
7+
class VllmImage(MaxKBBaseModel, BaseChatOpenAI):
    """Vision chat model backed by a vLLM OpenAI-compatible endpoint."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a :class:`VllmImage` from stored credential fields plus any
        recognized optional generation parameters."""
        # Drop any kwargs the base model does not recognize.
        extra_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return VllmImage(
            model_name=model_name,
            openai_api_key=model_credential.get('api_key'),
            openai_api_base=model_credential.get('api_base'),
            # stream_options={"include_usage": True},
            streaming=True,
            stream_usage=True,
            **extra_params,
        )

0 commit comments

Comments
 (0)