Skip to content

Commit 64b03af

Browse files
liuruibinwxg0103
authored andcommitted
refactor: support volcanic engine embeddings
1 parent 19fabbc commit 64b03af

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from django.utils.translation import gettext_lazy as _
1616

1717

18-
class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
18+
class VolcanicEmbeddingCredential(BaseForm, BaseModelCredential):
1919
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
2020
raise_exception=True):
2121
model_type_list = provider.get_model_type_list()
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from typing import Dict
22

3-
from langchain_community.embeddings import VolcanoEmbeddings
3+
from langchain_openai import OpenAIEmbeddings
44

55
from setting.models_provider.base_model_provider import MaxKBBaseModel
66

77

8-
class VolcanicEngineEmbeddingModel(MaxKBBaseModel, VolcanoEmbeddings):
8+
class VolcanicEngineEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings):
99
@staticmethod
1010
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
1111
return VolcanicEngineEmbeddingModel(
12-
api_key=model_credential.get('api_key'),
12+
openai_api_key=model_credential.get('api_key'),
1313
model=model_name,
1414
openai_api_base=model_credential.get('api_base'),
15+
check_embedding_ctx_length=False,
1516
)

apps/setting/models_provider/impl/volcanic_engine_model_provider/volcanic_engine_model_provider.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
1515
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
1616
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
17+
from setting.models_provider.impl.volcanic_engine_model_provider.credential.embedding import VolcanicEmbeddingCredential
1718
from setting.models_provider.impl.volcanic_engine_model_provider.credential.image import \
1819
VolcanicEngineImageModelCredential
1920
from setting.models_provider.impl.volcanic_engine_model_provider.credential.tti import VolcanicEngineTTIModelCredential
2021
from setting.models_provider.impl.volcanic_engine_model_provider.credential.tts import VolcanicEngineTTSModelCredential
22+
from setting.models_provider.impl.volcanic_engine_model_provider.model.embedding import VolcanicEngineEmbeddingModel
2123
from setting.models_provider.impl.volcanic_engine_model_provider.model.image import VolcanicEngineImage
2224
from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel
2325
from setting.models_provider.impl.volcanic_engine_model_provider.credential.stt import VolcanicEngineSTTModelCredential
@@ -82,12 +84,13 @@
8284
),
8385
]
8486

85-
open_ai_embedding_credential = OpenAIEmbeddingCredential()
87+
open_ai_embedding_credential = VolcanicEmbeddingCredential()
8688
model_info_embedding_list = [
8789
ModelInfo('ep-xxxxxxxxxx-yyyy',
8890
_('The user goes to the model inference page of Volcano Ark to create an inference access point. Here, you need to enter ep-xxxxxxxxxx-yyyy to call it.'),
8991
ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
90-
OpenAIEmbeddingModel)]
92+
VolcanicEngineEmbeddingModel)
93+
]
9194

9295
model_info_manage = (
9396
ModelInfoManage.builder()
@@ -97,6 +100,8 @@
97100
.append_default_model_info(model_info_list[2])
98101
.append_default_model_info(model_info_list[3])
99102
.append_default_model_info(model_info_list[4])
103+
.append_model_info_list(model_info_embedding_list)
104+
.append_default_model_info(model_info_embedding_list[0])
100105
.build()
101106
)
102107

0 commit comments

Comments
 (0)