
Commit b7268ce

feat: add optional parameters to OpenAIEmbeddingModel for enhanced embedding functionality
1 parent 6b23469 commit b7268ce

File tree: 2 files changed, 40 additions and 47 deletions


apps/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py

Lines changed: 28 additions & 45 deletions
@@ -6,61 +6,44 @@
 @date:2024/10/16 16:34
 @desc:
 """
-from functools import reduce
 from typing import Dict, List
 
-from langchain_community.embeddings import DashScopeEmbeddings
-from langchain_community.embeddings.dashscope import embed_with_retry
+from openai import OpenAI
 
 from models_provider.base_model_provider import MaxKBBaseModel
 
 
-def proxy_embed_documents(texts: List[str], step_size, embed_documents):
-    value = [embed_documents(texts[start_index:start_index + step_size]) for start_index in
-             range(0, len(texts), step_size)]
-    return reduce(lambda x, y: [*x, *y], value, [])
+class AliyunBaiLianEmbedding(MaxKBBaseModel):
+    model_name: str
+    optional_params: dict
 
+    def __init__(self, api_key, base_url, model_name: str, optional_params: dict):
+        self.client = OpenAI(api_key=api_key, base_url=base_url).embeddings
+        self.model_name = model_name
+        self.optional_params = optional_params
 
-class AliyunBaiLianEmbedding(MaxKBBaseModel, DashScopeEmbeddings):
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return AliyunBaiLianEmbedding(
-            model=model_name,
-            dashscope_api_key=model_credential.get('dashscope_api_key')
+            api_key=model_credential.get('api_key'),
+            model_name=model_name,
+            base_url=model_credential.get('api_base'),
+            optional_params=optional_params
        )
 
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        if self.model == 'text-embedding-v3':
-            return proxy_embed_documents(texts, 6, self._embed_documents)
-        return self._embed_documents(texts)
-
-    def _embed_documents(self, texts: List[str]) -> List[List[float]]:
-        """Call out to DashScope's embedding endpoint for embedding search docs.
-
-        Args:
-            texts: The list of texts to embed.
-            chunk_size: The chunk size of embeddings. If None, will use the chunk size
-                specified by the class.
-
-        Returns:
-            List of embeddings, one for each text.
-        """
-        embeddings = embed_with_retry(
-            self, input=texts, text_type="document", model=self.model
-        )
-        embedding_list = [item["embedding"] for item in embeddings]
-        return embedding_list
-
-    def embed_query(self, text: str) -> List[float]:
-        """Call out to DashScope's embedding endpoint for embedding query text.
-
-        Args:
-            text: The text to embed.
-
-        Returns:
-            Embedding for the text.
-        """
-        embedding = embed_with_retry(
-            self, input=[text], text_type="document", model=self.model
-        )[0]["embedding"]
-        return embedding
+    def embed_query(self, text: str):
+        res = self.embed_documents([text])
+        return res[0]
+
+    def embed_documents(
+            self, texts: List[str], chunk_size: int | None = None
+    ) -> List[List[float]]:
+        if len(self.optional_params) > 0:
+            res = self.client.create(
+                input=texts, model=self.model_name, encoding_format="float",
+                **self.optional_params
+            )
+        else:
+            res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
+        return [e.embedding for e in res.data]
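
For orientation, a minimal usage sketch of the rewritten Aliyun Bai Lian class follows. The import path mirrors the file location above, while the credential values, the OpenAI-compatible DashScope base URL, the 'EMBEDDING' model type and the dimensions keyword are illustrative assumptions rather than values taken from this commit; the sketch only assumes that extra keyword arguments survive MaxKBBaseModel.filter_optional_params and are forwarded to embeddings.create().

# Hypothetical usage sketch; the values below are placeholders, not part of this commit.
from models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding

credential = {
    'api_key': 'sk-xxxx',                                             # placeholder credential
    'api_base': 'https://dashscope.aliyuncs.com/compatible-mode/v1',  # assumed OpenAI-compatible endpoint
}

# Keyword arguments beyond the credentials (here 'dimensions') are assumed to be
# collected by MaxKBBaseModel.filter_optional_params into optional_params and then
# unpacked into self.client.create(...) inside embed_documents.
model = AliyunBaiLianEmbedding.new_instance(
    'EMBEDDING', 'text-embedding-v3', credential, dimensions=1024
)

document_vectors = model.embed_documents(['first chunk', 'second chunk'])  # List[List[float]]
query_vector = model.embed_query('a search query')                         # List[float]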

apps/models_provider/impl/openai_model_provider/model/embedding.py

Lines changed: 12 additions & 2 deletions
@@ -15,17 +15,21 @@
 
 class OpenAIEmbeddingModel(MaxKBBaseModel):
     model_name: str
+    optional_params: dict
 
-    def __init__(self, api_key, base_url, model_name: str):
+    def __init__(self, api_key, base_url, model_name: str, optional_params: dict):
         self.client = openai.OpenAI(api_key=api_key, base_url=base_url).embeddings
         self.model_name = model_name
+        self.optional_params = optional_params
 
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return OpenAIEmbeddingModel(
             api_key=model_credential.get('api_key'),
             model_name=model_name,
             base_url=model_credential.get('api_base'),
+            optional_params=optional_params
         )
 
     def embed_query(self, text: str):
@@ -35,5 +39,11 @@ def embed_query(self, text: str):
     def embed_documents(
             self, texts: List[str], chunk_size: int | None = None
     ) -> List[List[float]]:
-        res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
+        if len(self.optional_params) > 0:
+            res = self.client.create(
+                input=texts, model=self.model_name, encoding_format="float",
+                **self.optional_params
+            )
+        else:
+            res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
         return [e.embedding for e in res.data]
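
A side note on the new branch in embed_documents: unpacking an empty dict into a keyword call is a no-op, so the if/else could in principle collapse into a single create() call. The sketch below shows that alternative shape for the OpenAI model; it is an observation about the pattern, not a change made by this commit.

    def embed_documents(
            self, texts: List[str], chunk_size: int | None = None
    ) -> List[List[float]]:
        # **{} expands to nothing, so one call covers both the "with optional
        # params" and "without optional params" cases handled by the if/else above.
        res = self.client.create(
            input=texts,
            model=self.model_name,
            encoding_format="float",
            **self.optional_params,
        )
        return [e.embedding for e in res.data]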
