|
6 | 6 | @date:2024/7/12 17:44 |
7 | 7 | @desc: |
8 | 8 | """ |
9 | | -from typing import Dict |
| 9 | +from typing import Dict, List |
10 | 10 |
|
11 | | -from langchain_community.embeddings import OpenAIEmbeddings |
| 11 | +import openai |
12 | 12 |
|
13 | 13 | from setting.models_provider.base_model_provider import MaxKBBaseModel |
14 | 14 |
|
15 | 15 |
|
16 | | -class OpenAIEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings): |
| 16 | +class OpenAIEmbeddingModel(MaxKBBaseModel): |
| 17 | + model_name: str |
| 18 | + |
| 19 | + def __init__(self, api_key, base_url, model_name: str): |
| 20 | + self.client = openai.OpenAI(api_key=api_key, base_url=base_url).embeddings |
| 21 | + self.model_name = model_name |
| 22 | + |
17 | 23 | @staticmethod |
18 | 24 | def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): |
19 | 25 | return OpenAIEmbeddingModel( |
20 | 26 | api_key=model_credential.get('api_key'), |
21 | | - model=model_name, |
22 | | - openai_api_base=model_credential.get('api_base'), |
| 27 | + model_name=model_name, |
| 28 | + base_url=model_credential.get('api_base'), |
23 | 29 | ) |
| 30 | + |
| 31 | + def embed_query(self, text: str): |
| 32 | + res = self.embed_documents([text]) |
| 33 | + return res[0] |
| 34 | + |
| 35 | + def embed_documents( |
| 36 | + self, texts: List[str], chunk_size: int | None = None |
| 37 | + ) -> List[List[float]]: |
| 38 | + res = self.client.create(input=texts, model=self.model_name, encoding_format="float") |
| 39 | + return [e.embedding for e in res.data] |
0 commit comments