Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions dingo/model/llm/rag/llm_rag_answer_relevancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# 用于embedding的模型,支持OpenAI和HuggingFace
class EmbeddingModel:
"""Embedding模型接口,支持OpenAI和HuggingFace模型"""
def __init__(self, model_name: str = "text-embedding-3-large", is_openai: bool = True):
def __init__(self, model_name: str = "text-embedding-3-large", is_openai: bool = True, api_key: str = None, base_url: str = None):
self.is_openai = is_openai
self.model_name = model_name

Expand All @@ -32,8 +32,8 @@ def __init__(self, model_name: str = "text-embedding-3-large", is_openai: bool =

from openai import OpenAI
self.client = OpenAI(
api_key="API-KEY",
base_url="API-KEY-BASE-URL"
api_key=api_key,
base_url=base_url
)
else:
# 使用HuggingFace Embeddings
Expand Down Expand Up @@ -127,7 +127,18 @@ def init_embedding_model(cls, model_name: str = "text-embedding-3-large"):
"""初始化embedding模型"""
# 检查是否是OpenAI模型
is_openai = model_name.startswith("text-embedding-")
cls.embedding_model = EmbeddingModel(model_name, is_openai)
api_key = None
base_url = None
if is_openai:
# 从配置中获取API密钥和base_url
if not cls.dynamic_config.key:
raise ValueError("key cannot be empty in llm config.")
elif not cls.dynamic_config.api_url:
raise ValueError("api_url cannot be empty in llm config.")
else:
api_key = cls.dynamic_config.key
base_url = cls.dynamic_config.api_url
Comment on lines +133 to +140
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The if/elif/else structure for validating key and api_url is a bit complex. You can simplify this by using two separate if statements to check for each required parameter. This makes the validation logic more direct and easier to read.

            # 从配置中获取API密钥和base_url
            if not cls.dynamic_config.key:
                raise ValueError("key cannot be empty in llm config.")
            if not cls.dynamic_config.api_url:
                raise ValueError("api_url cannot be empty in llm config.")
            api_key = cls.dynamic_config.key
            base_url = cls.dynamic_config.api_url

cls.embedding_model = EmbeddingModel(model_name, is_openai, api_key, base_url)

@classmethod
def build_messages(cls, input_data: Data) -> List:
Expand Down Expand Up @@ -265,7 +276,7 @@ def eval(cls, input_data: Data) -> ModelRes:
result = ModelRes()
result.score = score

# 根据分数判断是否通过(默认阈值5,满分10分)
# 根据分数判断是否通过,默认阈值为5
threshold = 5
if hasattr(cls, 'dynamic_config') and cls.dynamic_config.parameters:
threshold = cls.dynamic_config.parameters.get('threshold', 5)
Expand Down
2 changes: 1 addition & 1 deletion dingo/model/llm/rag/llm_rag_context_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def process_response(cls, responses: List[str]) -> ModelRes:
result = ModelRes()
result.score = score

# 根据分数判断是否通过(默认阈值5,满分10分)
# 根据分数判断是否通过,默认阈值为5
threshold = 5
if hasattr(cls, 'dynamic_config') and cls.dynamic_config.parameters:
threshold = cls.dynamic_config.parameters.get('threshold', 5)
Expand Down
3 changes: 2 additions & 1 deletion dingo/model/llm/rag/llm_rag_context_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class LLMRAGContextRecall(BaseOpenAI):

prompt = """上下文召回评估提示词,用于分类陈述归因"""

@staticmethod
def context_recall_prompt(question: str, context: str, answer: str) -> str:
"""
生成上下文召回评估的提示词
Expand Down Expand Up @@ -200,7 +201,7 @@ def process_response(cls, response: str) -> ModelRes:
result = ModelRes()
result.score = score

# 根据分数判断是否通过(默认阈值5,满分10分)
# 根据分数判断是否通过,默认阈值为5
threshold = 5
if hasattr(cls, 'dynamic_config') and cls.dynamic_config.parameters:
threshold = cls.dynamic_config.parameters.get('threshold', 5)
Expand Down
4 changes: 3 additions & 1 deletion dingo/model/llm/rag/llm_rag_context_relevancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class LLMRAGContextRelevancy(BaseOpenAI):
"source_frameworks": "Ragas + DeepEval + TruLens"
}

@staticmethod
def context_relevance_judge1_prompt(query: str, context: str) -> str:
"""
First judge template for context relevance evaluation (Chinese version).
Expand Down Expand Up @@ -80,6 +81,7 @@ def context_relevance_judge1_prompt(query: str, context: str) -> str:
请不要尝试解释。
分析上下文和问题后,相关性分数为 """

@staticmethod
def context_relevance_judge2_prompt(query: str, context: str) -> str:
"""
Second judge template for context relevance evaluation (Chinese version).
Expand Down Expand Up @@ -200,7 +202,7 @@ def process_response(cls, response: str) -> ModelRes:
result = ModelRes()
result.score = score

# 根据分数判断是否通过(默认阈值5,满分10分)
# 根据分数判断是否通过,默认阈值为5
threshold = 5
if hasattr(cls, 'dynamic_config') and cls.dynamic_config.parameters:
threshold = cls.dynamic_config.parameters.get('threshold', 5)
Expand Down
16 changes: 9 additions & 7 deletions dingo/model/llm/rag/llm_rag_faithfulness.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class LLMRAGFaithfulness(BaseOpenAI):
"source_frameworks": "Ragas + DeepEval"
}

@staticmethod
def statement_generator_prompt(question: str, answer: str) -> str:
"""
Prompt to generate statements from answer (Chinese version).
Expand All @@ -67,18 +68,19 @@ def statement_generator_prompt(question: str, answer: str) -> str:

请以JSON格式返回结果,格式如下:
```json
{
{{
"statements": [
"陈述1",
"陈述2",
"陈述3"
]
}
}}
```

请不要输出其他内容,只返回JSON格式的结果。
"""

@staticmethod
def faithfulness_judge_prompt(context: str, statements: List[str]) -> str:
"""
Prompt to judge faithfulness of statements (Chinese version).
Expand All @@ -103,15 +105,15 @@ def faithfulness_judge_prompt(context: str, statements: List[str]) -> str:

请以JSON格式返回结果,格式如下:
```json
{
{{
"statements": [
{
{{
"statement": "原始陈述,一字不差",
"reason": "判断理由",
"verdict": 0或1
}
}}
]
}
}}
```

请不要输出其他内容,只返回JSON格式的结果。
Expand Down Expand Up @@ -284,7 +286,7 @@ def process_response(cls, response: str) -> ModelRes:
result = ModelRes()
result.score = score

# 根据分数判断是否通过(默认阈值5,满分10分)
# 根据分数判断是否通过,默认阈值为5
threshold = 5
if hasattr(cls, 'dynamic_config') and cls.dynamic_config.parameters:
threshold = cls.dynamic_config.parameters.get('threshold', 5)
Expand Down