 """
 @date:2024/10/16 16:34
 @desc:
 """
-from functools import reduce
 from typing import Dict, List

-from langchain_community.embeddings import DashScopeEmbeddings
-from langchain_community.embeddings.dashscope import embed_with_retry
+from openai import OpenAI

 from models_provider.base_model_provider import MaxKBBaseModel


-def proxy_embed_documents(texts: List[str], step_size, embed_documents):
-    value = [embed_documents(texts[start_index:start_index + step_size]) for start_index in
-             range(0, len(texts), step_size)]
-    return reduce(lambda x, y: [*x, *y], value, [])
+class AliyunBaiLianEmbedding(MaxKBBaseModel):
+    model_name: str
+    optional_params: dict

+    def __init__(self, api_key, base_url, model_name: str, optional_params: dict):
+        self.client = OpenAI(api_key=api_key, base_url=base_url).embeddings
+        self.model_name = model_name
+        self.optional_params = optional_params

-class AliyunBaiLianEmbedding(MaxKBBaseModel, DashScopeEmbeddings):
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
         return AliyunBaiLianEmbedding(
-            model=model_name,
-            dashscope_api_key=model_credential.get('dashscope_api_key')
+            api_key=model_credential.get('api_key'),
+            model_name=model_name,
+            base_url=model_credential.get('api_base'),
+            optional_params=optional_params
         )

-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        if self.model == 'text-embedding-v3':
-            return proxy_embed_documents(texts, 6, self._embed_documents)
-        return self._embed_documents(texts)
-
-    def _embed_documents(self, texts: List[str]) -> List[List[float]]:
-        """Call out to DashScope's embedding endpoint for embedding search docs.
-
-        Args:
-            texts: The list of texts to embed.
-            chunk_size: The chunk size of embeddings. If None, will use the chunk size
-                specified by the class.
-
-        Returns:
-            List of embeddings, one for each text.
-        """
-        embeddings = embed_with_retry(
-            self, input=texts, text_type="document", model=self.model
-        )
-        embedding_list = [item["embedding"] for item in embeddings]
-        return embedding_list
-
-    def embed_query(self, text: str) -> List[float]:
-        """Call out to DashScope's embedding endpoint for embedding query text.
-
-        Args:
-            text: The text to embed.
-
-        Returns:
-            Embedding for the text.
-        """
-        embedding = embed_with_retry(
-            self, input=[text], text_type="document", model=self.model
-        )[0]["embedding"]
-        return embedding
+    def embed_query(self, text: str):
+        res = self.embed_documents([text])
+        return res[0]
+
+    def embed_documents(
+            self, texts: List[str], chunk_size: int | None = None
+    ) -> List[List[float]]:
+        if len(self.optional_params) > 0:
+            res = self.client.create(
+                input=texts, model=self.model_name, encoding_format="float",
+                **self.optional_params
+            )
+        else:
+            res = self.client.create(input=texts, model=self.model_name, encoding_format="float")
+        return [e.embedding for e in res.data]
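For orientation, a minimal usage sketch of the rewritten class as it appears in this diff. The credential keys mirror the `api_key` and `api_base` values read in `new_instance()`; the endpoint URL, model name, and model_type value are illustrative assumptions rather than values taken from this PR, and whether any extra keyword arguments reach the API depends on what `MaxKBBaseModel.filter_optional_params` keeps.

# Illustrative usage sketch (not part of the PR).
# Assumes AliyunBaiLianEmbedding is importable from the module changed in this diff.
from typing import Dict

credential: Dict[str, object] = {
    'api_key': 'sk-xxx',  # placeholder key
    'api_base': 'https://dashscope.aliyuncs.com/compatible-mode/v1',  # assumed OpenAI-compatible endpoint
}

# model_type is not used inside new_instance() here, so any marker value works.
embedder = AliyunBaiLianEmbedding.new_instance(
    'EMBEDDING', 'text-embedding-v3', credential
)

doc_vectors = embedder.embed_documents(['第一段文本', 'second passage'])  # one vector per input text
query_vector = embedder.embed_query('search query')  # single vector via embed_documents([text])[0]
print(len(doc_vectors), len(query_vector))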