|
14 | 14 | from dingo.model import Model |
15 | 15 | from dingo.model.llm.base_openai import BaseOpenAI |
16 | 16 | from dingo.model.modelres import ModelRes |
17 | | -from dingo.model.response.response_class import ResponseScoreReason |
18 | 17 | from dingo.utils import log |
19 | 18 | from dingo.utils.exception import ConvertJsonError |
20 | 19 |
|
21 | 20 |
|
# Embedding model wrapper supporting both OpenAI and HuggingFace backends.
class EmbeddingModel:
    """Embedding interface that supports OpenAI and HuggingFace models.

    Args:
        model_name: Name of the embedding model to use.
        is_openai: If True, call the OpenAI embeddings API; otherwise load a
            local SentenceTransformer model with the same name.
        api_key: API key for the OpenAI client (OpenAI mode only).
        base_url: Base URL for the OpenAI client (OpenAI mode only).
    """

    def __init__(self, model_name: str = "text-embedding-3-large", is_openai: bool = True, api_key: str = None, base_url: str = None):
        self.is_openai = is_openai
        self.model_name = model_name

        if is_openai:
            # Lazy import so the openai package is only required in OpenAI mode.
            # (Removed a dead `import os` that was never used.)
            from openai import OpenAI
            self.client = OpenAI(
                api_key=api_key,
                base_url=base_url
            )
        else:
            # Lazy import: sentence-transformers is only needed for local models.
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer(model_name)

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding vector for a single query string."""
        if self.is_openai:
            response = self.client.embeddings.create(
                model=self.model_name,
                input=text
            )
            return response.data[0].embedding
        else:
            # SentenceTransformer returns a numpy array; convert to a plain list.
            return self.model.encode(text).tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per document in *texts*."""
        if self.is_openai:
            response = self.client.embeddings.create(
                model=self.model_name,
                input=texts
            )
            return [data.embedding for data in response.data]
        else:
            return self.model.encode(texts).tolist()
65 | | - |
66 | 21 | @Model.llm_register("LLMRAGAnswerRelevancy") |
67 | 22 | class LLMRAGAnswerRelevancy(BaseOpenAI): |
68 | 23 | """ |
@@ -125,20 +80,15 @@ class LLMRAGAnswerRelevancy(BaseOpenAI): |
125 | 80 | @classmethod |
126 | 81 | def init_embedding_model(cls, model_name: str = "text-embedding-3-large"): |
127 | 82 | """初始化embedding模型""" |
128 | | - # 检查是否是OpenAI模型 |
129 | | - is_openai = model_name.startswith("text-embedding-") |
130 | | - api_key = None |
131 | | - base_url = None |
132 | | - if is_openai: |
133 | | - # 从配置中获取API密钥和base_url |
134 | | - if not cls.dynamic_config.key: |
135 | | - raise ValueError("key cannot be empty in llm config.") |
136 | | - elif not cls.dynamic_config.api_url: |
137 | | - raise ValueError("api_url cannot be empty in llm config.") |
138 | | - else: |
139 | | - api_key = cls.dynamic_config.key |
140 | | - base_url = cls.dynamic_config.api_url |
141 | | - cls.embedding_model = EmbeddingModel(model_name, is_openai, api_key, base_url) |
| 83 | + # 确保LLM客户端已经创建 |
| 84 | + if not hasattr(cls, 'client') or cls.client is None: |
| 85 | + cls.create_client() |
| 86 | + |
| 87 | + # 直接使用OpenAI的Embedding API |
| 88 | + cls.embedding_model = { |
| 89 | + 'model_name': model_name, |
| 90 | + 'client': cls.client |
| 91 | + } |
142 | 92 |
|
143 | 93 | @classmethod |
144 | 94 | def build_messages(cls, input_data: Data) -> List: |
@@ -210,8 +160,19 @@ def calculate_similarity(cls, question: str, generated_questions: List[str]) -> |
210 | 160 | cls.init_embedding_model() |
211 | 161 |
|
212 | 162 | # 生成embedding |
213 | | - question_vec = np.asarray(cls.embedding_model.embed_query(question)).reshape(1, -1) |
214 | | - gen_question_vec = np.asarray(cls.embedding_model.embed_documents(generated_questions)).reshape(len(generated_questions), -1) |
| 163 | + # 单个查询的embedding |
| 164 | + question_response = cls.embedding_model['client'].embeddings.create( |
| 165 | + model=cls.embedding_model['model_name'], |
| 166 | + input=question |
| 167 | + ) |
| 168 | + question_vec = np.asarray(question_response.data[0].embedding).reshape(1, -1) |
| 169 | + |
| 170 | + # 多个文档的embedding |
| 171 | + gen_questions_response = cls.embedding_model['client'].embeddings.create( |
| 172 | + model=cls.embedding_model['model_name'], |
| 173 | + input=generated_questions |
| 174 | + ) |
| 175 | + gen_question_vec = np.asarray([data.embedding for data in gen_questions_response.data]).reshape(len(generated_questions), -1) |
215 | 176 |
|
216 | 177 | # 计算余弦相似度 |
217 | 178 | norm = np.linalg.norm(gen_question_vec, axis=1) * np.linalg.norm(question_vec, axis=1) |
|
0 commit comments