Skip to content

Commit 3524162

Browse files
committed
feat: Support gemini embedding model
1 parent 24bb7d5 commit 3524162

File tree

3 files changed

+82
-0
lines changed

3 files changed

+82
-0
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: embedding.py
6+
@date:2024/7/12 16:45
7+
@desc:
8+
"""
9+
from typing import Dict
10+
11+
from common import forms
12+
from common.exception.app_exception import AppApiException
13+
from common.forms import BaseForm
14+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
15+
16+
17+
class GeminiEmbeddingCredential(BaseForm, BaseModelCredential):
18+
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
19+
raise_exception=True):
20+
model_type_list = provider.get_model_type_list()
21+
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
22+
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
23+
24+
for key in ['api_key']:
25+
if key not in model_credential:
26+
if raise_exception:
27+
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
28+
else:
29+
return False
30+
try:
31+
model = provider.get_model(model_type, model_name, model_credential)
32+
model.embed_query('你好')
33+
except Exception as e:
34+
if isinstance(e, AppApiException):
35+
raise e
36+
if raise_exception:
37+
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
38+
else:
39+
return False
40+
return True
41+
42+
def encryption_dict(self, model: Dict[str, object]):
43+
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
44+
45+
api_key = forms.PasswordInputField('API Key', required=True)

apps/setting/models_provider/impl/gemini_model_provider/gemini_model_provider.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
from common.util.file_util import get_file_content
1212
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
1313
ModelInfoManage
14+
from setting.models_provider.impl.gemini_model_provider.credential.embedding import GeminiEmbeddingCredential
1415
from setting.models_provider.impl.gemini_model_provider.credential.image import GeminiImageModelCredential
1516
from setting.models_provider.impl.gemini_model_provider.credential.llm import GeminiLLMModelCredential
1617
from setting.models_provider.impl.gemini_model_provider.credential.stt import GeminiSTTModelCredential
18+
from setting.models_provider.impl.gemini_model_provider.model.embedding import GeminiEmbeddingModel
1719
from setting.models_provider.impl.gemini_model_provider.model.image import GeminiImage
1820
from setting.models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel
1921
from setting.models_provider.impl.gemini_model_provider.model.stt import GeminiSpeechToText
@@ -22,6 +24,7 @@
2224
gemini_llm_model_credential = GeminiLLMModelCredential()
2325
gemini_image_model_credential = GeminiImageModelCredential()
2426
gemini_stt_model_credential = GeminiSTTModelCredential()
27+
gemini_embedding_model_credential = GeminiEmbeddingCredential()
2528

2629
model_info_list = [
2730
ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新',
@@ -56,14 +59,23 @@
5659
GeminiSpeechToText),
5760
]
5861

62+
model_embedding_info_list = [
63+
ModelInfo('models/embedding-001', '',
64+
ModelTypeConst.EMBEDDING,
65+
gemini_embedding_model_credential,
66+
GeminiEmbeddingModel),
67+
]
68+
5969
model_info_manage = (
6070
ModelInfoManage.builder()
6171
.append_model_info_list(model_info_list)
6272
.append_model_info_list(model_image_info_list)
6373
.append_model_info_list(model_stt_info_list)
74+
.append_model_info_list(model_embedding_info_list)
6475
.append_default_model_info(model_info_list[0])
6576
.append_default_model_info(model_image_info_list[0])
6677
.append_default_model_info(model_stt_info_list[0])
78+
.append_default_model_info(model_embedding_info_list[0])
6779
.build()
6880
)
6981

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: embedding.py
6+
@date:2024/7/12 17:44
7+
@desc:
8+
"""
9+
from typing import Dict
10+
11+
from langchain_google_genai import GoogleGenerativeAIEmbeddings
12+
13+
from setting.models_provider.base_model_provider import MaxKBBaseModel
14+
15+
16+
class GeminiEmbeddingModel(MaxKBBaseModel, GoogleGenerativeAIEmbeddings):
17+
@staticmethod
18+
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
19+
return GoogleGenerativeAIEmbeddings(
20+
google_api_key=model_credential.get('api_key'),
21+
model=model_name,
22+
)
23+
24+
def is_cache_model(self):
25+
return False

0 commit comments

Comments
 (0)