Skip to content

Commit 3c827bf

Browse files
authored
Merge pull request #277 from pekopoke/dev_1209
fix : embedding model change
2 parents 1a66e46 + e69a571 commit 3c827bf

File tree

1 file changed

+22
-61
lines changed

1 file changed

+22
-61
lines changed

dingo/model/llm/rag/llm_rag_answer_relevancy.py

Lines changed: 22 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -14,55 +14,10 @@
1414
from dingo.model import Model
1515
from dingo.model.llm.base_openai import BaseOpenAI
1616
from dingo.model.modelres import ModelRes
17-
from dingo.model.response.response_class import ResponseScoreReason
1817
from dingo.utils import log
1918
from dingo.utils.exception import ConvertJsonError
2019

2120

22-
# 用于embedding的模型,支持OpenAI和HuggingFace
23-
class EmbeddingModel:
24-
"""Embedding模型接口,支持OpenAI和HuggingFace模型"""
25-
def __init__(self, model_name: str = "text-embedding-3-large", is_openai: bool = True, api_key: str = None, base_url: str = None):
26-
self.is_openai = is_openai
27-
self.model_name = model_name
28-
29-
if is_openai:
30-
# 使用OpenAI Embeddings
31-
import os
32-
33-
from openai import OpenAI
34-
self.client = OpenAI(
35-
api_key=api_key,
36-
base_url=base_url
37-
)
38-
else:
39-
# 使用HuggingFace Embeddings
40-
from sentence_transformers import SentenceTransformer
41-
self.model = SentenceTransformer(model_name)
42-
43-
def embed_query(self, text: str) -> List[float]:
44-
"""生成查询的embedding"""
45-
if self.is_openai:
46-
response = self.client.embeddings.create(
47-
model=self.model_name,
48-
input=text
49-
)
50-
return response.data[0].embedding
51-
else:
52-
return self.model.encode(text).tolist()
53-
54-
def embed_documents(self, texts: List[str]) -> List[List[float]]:
55-
"""生成多个文档的embedding"""
56-
if self.is_openai:
57-
response = self.client.embeddings.create(
58-
model=self.model_name,
59-
input=texts
60-
)
61-
return [data.embedding for data in response.data]
62-
else:
63-
return self.model.encode(texts).tolist()
64-
65-
6621
@Model.llm_register("LLMRAGAnswerRelevancy")
6722
class LLMRAGAnswerRelevancy(BaseOpenAI):
6823
"""
@@ -125,20 +80,15 @@ class LLMRAGAnswerRelevancy(BaseOpenAI):
12580
@classmethod
12681
def init_embedding_model(cls, model_name: str = "text-embedding-3-large"):
12782
"""初始化embedding模型"""
128-
# 检查是否是OpenAI模型
129-
is_openai = model_name.startswith("text-embedding-")
130-
api_key = None
131-
base_url = None
132-
if is_openai:
133-
# 从配置中获取API密钥和base_url
134-
if not cls.dynamic_config.key:
135-
raise ValueError("key cannot be empty in llm config.")
136-
elif not cls.dynamic_config.api_url:
137-
raise ValueError("api_url cannot be empty in llm config.")
138-
else:
139-
api_key = cls.dynamic_config.key
140-
base_url = cls.dynamic_config.api_url
141-
cls.embedding_model = EmbeddingModel(model_name, is_openai, api_key, base_url)
83+
# 确保LLM客户端已经创建
84+
if not hasattr(cls, 'client') or cls.client is None:
85+
cls.create_client()
86+
87+
# 直接使用OpenAI的Embedding API
88+
cls.embedding_model = {
89+
'model_name': model_name,
90+
'client': cls.client
91+
}
14292

14393
@classmethod
14494
def build_messages(cls, input_data: Data) -> List:
@@ -210,8 +160,19 @@ def calculate_similarity(cls, question: str, generated_questions: List[str]) ->
210160
cls.init_embedding_model()
211161

212162
# 生成embedding
213-
question_vec = np.asarray(cls.embedding_model.embed_query(question)).reshape(1, -1)
214-
gen_question_vec = np.asarray(cls.embedding_model.embed_documents(generated_questions)).reshape(len(generated_questions), -1)
163+
# 单个查询的embedding
164+
question_response = cls.embedding_model['client'].embeddings.create(
165+
model=cls.embedding_model['model_name'],
166+
input=question
167+
)
168+
question_vec = np.asarray(question_response.data[0].embedding).reshape(1, -1)
169+
170+
# 多个文档的embedding
171+
gen_questions_response = cls.embedding_model['client'].embeddings.create(
172+
model=cls.embedding_model['model_name'],
173+
input=generated_questions
174+
)
175+
gen_question_vec = np.asarray([data.embedding for data in gen_questions_response.data]).reshape(len(generated_questions), -1)
215176

216177
# 计算余弦相似度
217178
norm = np.linalg.norm(gen_question_vec, axis=1) * np.linalg.norm(question_vec, axis=1)

0 commit comments

Comments
 (0)