Dev lld: update 5 metrics and dataset for rags #273
Conversation
Summary of Changes

Hello @pekopoke, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request introduces a substantial overhaul of the RAG evaluation capabilities within the repository.
Code Review
This pull request introduces a significant refactoring of the RAG evaluation metrics, aligning them with RAGAS implementations. This includes new, more sophisticated prompts and evaluation logic for faithfulness, context precision, context recall, context relevancy, and answer relevancy. A new example script for batch evaluation is also added.
While the new evaluation logic is a great improvement, the implementation has several critical issues that need to be addressed. These include hardcoded API credentials, a missing import that will cause a runtime crash, and incorrect method definitions that will lead to AttributeErrors. There are also thread-safety concerns due to the modification of class-level state during evaluation. Additionally, I've identified opportunities for code cleanup, such as removing dead code and refactoring duplicated logic in the example script.
Please review the detailed comments for specific suggestions on how to resolve these issues.
```python
except Exception as e:
    attempts += 1
    log.error(f"发送消息失败 (尝试 {attempts}/3): {e}")
    time.sleep(1)
```
```python
api_key="API-KEY",
base_url="API-KEY-BASE-URL"
```
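The review overview above flags these hardcoded API credentials. As a minimal sketch (the environment variable names below are assumptions, not part of this PR), the values could be read from the environment instead:

```python
import os

# Hypothetical variable names; use whatever this project documents.
api_key = os.environ.get("OPENAI_API_KEY")
base_url = os.environ.get("OPENAI_BASE_URL")
if not api_key or not base_url:
    raise RuntimeError("OPENAI_API_KEY and OPENAI_BASE_URL must be set")
```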
```python
}

prompt = """你是一个信息相关性评估专家。你的任务是评估检索到的上下文是否与给定问题相关。
def context_relevance_judge1_prompt(query: str, context: str) -> str:
```
This function context_relevance_judge1_prompt is defined within the class scope but not as a method. It's then called as a class method cls.context_relevance_judge1_prompt(...) in build_messages, which will raise an AttributeError. Since this function and context_relevance_judge2_prompt don't use cls or self, they should be decorated with @staticmethod.
```python
@staticmethod
def context_relevance_judge1_prompt(query: str, context: str) -> str:
```

```python
2. 对每个陈述判断是否可以从上下文归因
3. 计算召回率分数 = (可归因陈述数 / 总陈述数) × 10
4. 以JSON格式返回结果,不要输出其他内容
def context_recall_prompt(question: str, context: str, answer: str) -> str:
```
This function context_recall_prompt is defined within the class scope but not as a method. It's then called as a class method cls.context_recall_prompt(...) in build_messages, which will raise an AttributeError. Since this function doesn't use cls or self, it should be decorated with @staticmethod.
```python
@staticmethod
def context_recall_prompt(question: str, context: str, answer: str) -> str:
```

```python
raise ValueError("Answer Relevancy评估需要answer字段")

# 使用json.dumps()来安全转义响应字符串
import json
```
````python
def statement_generator_prompt(question: str, answer: str) -> str:
    """
    Prompt to generate statements from answer (Chinese version).

    Args:
        question: The user's question
        answer: The generated answer

    Returns:
        Prompt string for statement generation
    """
    safe_question = json.dumps(question)
    safe_answer = json.dumps(answer)

    return f"""### 指令
给定一个问题和一个答案,请分析答案中每个句子的复杂性。将每个句子分解为一个或多个完全可理解的陈述。确保在任何陈述中都不使用代词。
### 问题:{safe_question}
### 答案:{safe_answer}
请以JSON格式返回结果,格式如下:
```json
{
  "statements": [
    "陈述1",
    "陈述2",
    "陈述3"
  ]
}
```
请不要输出其他内容,只返回JSON格式的结果。
"""


def faithfulness_judge_prompt(context: str, statements: List[str]) -> str:
    """
    Prompt to judge faithfulness of statements (Chinese version).

    Args:
        context: The retrieved context
        statements: List of statements to evaluate

    Returns:
        Prompt string for faithfulness judgment
    """
    safe_context = json.dumps(context)
    safe_statements = json.dumps(statements)

    return f"""### 指令
你的任务是根据给定的上下文判断一系列陈述的忠实度。对于每个陈述,如果可以从上下文中直接推导出该陈述,请返回verdict为1;如果无法从上下文中直接推导出该陈述,请返回verdict为0。
### 上下文:{safe_context}
### 陈述列表:{safe_statements}
请以JSON格式返回结果,格式如下:
```json
{
  "statements": [
    {
      "statement": "原始陈述,一字不差",
      "reason": "判断理由",
      "verdict": 0或1
    }
  ]
}
```
请不要输出其他内容,只返回JSON格式的结果。
"""
````

````text
**评估流程**:
1. 从答案中提取独立的事实陈述
2. 对每个陈述验证是否能从上下文推导
3. 计算忠实陈述的比例
**判断标准**:
- faithful (忠实): 陈述可以从上下文中直接推导或明确支持
- unfaithful (不忠实): 陈述无法从上下文推导,或与上下文矛盾,或包含上下文中没有的信息
**问题**:
{0}
**答案**:
{1}
**上下文**:
{2}
**任务要求**:
1. 提取答案中的独立陈述(每个陈述应该是完整的、可独立验证的事实)
2. 对每个陈述判断是否忠实于上下文
3. 计算忠实度分数 = 忠实陈述数量 / 总陈述数量
4. 以JSON格式返回结果,不要输出其他内容
**输出格式**:
```json
{{
  "score": 0-10,
  "reason": "评估理由说明"
}}
```
其中score为0-10之间的整数,10表示完全忠实,0表示完全不忠实。
````
These prompt-generating functions, statement_generator_prompt and faithfulness_judge_prompt, appear to be dead code. The build_messages method now uses a single, large, inline f-string prompt that combines both steps. If these functions are no longer used, they should be removed to improve code clarity.
```python
def evaluate_from_jsonl(jsonl_path):
    """从JSONL文件读取数据并进行RAG指标评测"""
    logger.info(f"\n从JSONL文件 {jsonl_path} 读取数据进行评测...")
    print(f"\n从JSONL文件 {jsonl_path} 读取数据进行评测...")

    # 配置所有LLM评估器
    llm_args = EvaluatorLLMArgs(
        key=OPENAI_KEY,
        api_url=OPENAI_URL,
        model=OPENAI_MODEL,
    )

    # 设置所有评估器的LLM配置
    LLMRAGFaithfulness.dynamic_config = llm_args
    LLMRAGContextPrecision.dynamic_config = llm_args
    LLMRAGContextRecall.dynamic_config = llm_args
    LLMRAGContextRelevancy.dynamic_config = llm_args

    # 为AnswerRelevancy配置额外的参数(包括embedding模型)
    LLMRAGAnswerRelevancy.dynamic_config = EvaluatorLLMArgs(
        key=OPENAI_KEY,
        api_url=OPENAI_URL,
        model=OPENAI_MODEL,
        parameters={
            "embedding_model": EMBEDDING_MODEL,
            "strictness": 3,
            "threshold": 5
        }
    )

    # 初始化Embedding模型
    LLMRAGAnswerRelevancy.init_embedding_model(EMBEDDING_MODEL)

    # 读取JSONL文件
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        total_rows = 0

        # 初始化累计总分
        total_faithfulness = 0
        total_precision = 0
        total_recall = 0
        total_relevancy = 0
        total_answer_relevancy = 0

        # 遍历每一行数据
        for line in f:
            total_rows += 1

            # 解析JSON行
            row = json.loads(line.strip())

            logger.info(f"\n处理第 {total_rows} 条数据:")
            logger.info(f"问题: {row['question']}")
            print(f"\n处理第 {total_rows} 条数据:")
            print(f"问题: {row['question']}")

            # 获取retrieved_contexts(支持字符串列表或单个字符串)
            retrieved_contexts = row.get('retrieved_contexts', [])
            if isinstance(retrieved_contexts, str):
                retrieved_contexts = [retrieved_contexts]

            # 创建Data对象
            data = Data(
                data_id=f"jsonl_row_{total_rows}",
                prompt=row['question'],
                content=row['response'],
                context=retrieved_contexts,
                reference=row.get('reference', '')  # 标准答案是可选的
            )

            # 进行各项指标评测
            print("\n1. 忠实度 (Faithfulness):")
            faithfulness_result = LLMRAGFaithfulness.eval(data)
            print(f" 状态: {'✅ 通过' if not faithfulness_result.eval_status else '❌ 未通过'}")
            print(f" 分数: {faithfulness_result.score}/10")
            total_faithfulness += faithfulness_result.score

            logger.info("\n2. 上下文精度 (Context Precision):")
            print("\n2. 上下文精度 (Context Precision):")
            precision_result = LLMRAGContextPrecision.eval(data)
            logger.info(f" 状态: {'✅ 通过' if not precision_result.eval_status else '❌ 未通过'}")
            logger.info(f" 分数: {precision_result.score}/10")
            print(f" 状态: {'✅ 通过' if not precision_result.eval_status else '❌ 未通过'}")
            print(f" 分数: {precision_result.score}/10")
            total_precision += precision_result.score

            print("\n3. 上下文召回 (Context Recall):")
            recall_result = LLMRAGContextRecall.eval(data)
            print(f" 状态: {'✅ 通过' if not recall_result.eval_status else '❌ 未通过'}")
            print(f" 分数: {recall_result.score}/10")
            total_recall += recall_result.score

            print("\n4. 上下文相关性 (Context Relevancy):")
            relevancy_result = LLMRAGContextRelevancy.eval(data)
            print(f" 状态: {'✅ 通过' if not relevancy_result.eval_status else '❌ 未通过'}")
            print(f" 分数: {relevancy_result.score}/10")
            total_relevancy += relevancy_result.score

            print("\n5. 答案相关性 (Answer Relevancy):")
            answer_relevancy_result = LLMRAGAnswerRelevancy.eval(data)
            print(f" 状态: {'✅ 通过' if not answer_relevancy_result.eval_status else '❌ 未通过'}")
            print(f" 分数: {answer_relevancy_result.score}/10")
            total_answer_relevancy += answer_relevancy_result.score

        logger.info(f"\n所有 {total_rows} 条数据评测完成!")
        print(f"\n所有 {total_rows} 条数据评测完成!")

        # 计算并打印平均得分
        if total_rows > 0:
            avg_faithfulness = total_faithfulness / total_rows
            avg_precision = total_precision / total_rows
            avg_recall = total_recall / total_rows
            avg_relevancy = total_relevancy / total_rows
            avg_answer_relevancy = total_answer_relevancy / total_rows

            logger.info("\n" + "=" * 60)
            logger.info("🚀 RAG 指标平均得分")
            logger.info("=" * 60)
            logger.info(f"忠实度 (Faithfulness) 平均值: {avg_faithfulness:.2f}/10")
            logger.info(f"上下文精度 (Context Precision) 平均值: {avg_precision:.2f}/10")
            logger.info(f"上下文召回 (Context Recall) 平均值: {avg_recall:.2f}/10")
            logger.info(f"上下文相关性 (Context Relevancy) 平均值: {avg_relevancy:.2f}/10")
            logger.info(f"答案相关性 (Answer Relevancy) 平均值: {avg_answer_relevancy:.2f}/10")

            # 计算所有指标的总平均值
            overall_avg = (avg_faithfulness + avg_precision + avg_recall + avg_relevancy + avg_answer_relevancy) / 5
            logger.info(f"\n📊 综合平均得分: {overall_avg:.2f}/10")
            logger.info("=" * 60)

            print("\n" + "=" * 60)
            print("🚀 RAG 指标平均得分")
            print("=" * 60)
            print(f"忠实度 (Faithfulness) 平均值: {avg_faithfulness:.2f}/10")
            print(f"上下文精度 (Context Precision) 平均值: {avg_precision:.2f}/10")
            print(f"上下文召回 (Context Recall) 平均值: {avg_recall:.2f}/10")
            print(f"上下文相关性 (Context Relevancy) 平均值: {avg_relevancy:.2f}/10")
            print(f"答案相关性 (Answer Relevancy) 平均值: {avg_answer_relevancy:.2f}/10")

            # 计算所有指标的总平均值
            overall_avg = (avg_faithfulness + avg_precision + avg_recall + avg_relevancy + avg_answer_relevancy) / 5
            print(f"\n📊 综合平均得分: {overall_avg:.2f}/10")
            print("=" * 60)
```
There is significant code duplication between evaluate_from_jsonl and evaluate_from_csv (which follows a similar pattern). The core logic for setting up evaluators, iterating through data, performing evaluations, and calculating average scores is nearly identical. To improve maintainability, consider refactoring this common logic into a helper function that accepts an iterable of data rows.
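As a rough sketch of that refactor (the helper name is hypothetical, and the evaluator classes, Data, logger and json are assumed to be imported and configured exactly as in the script above), both entry points could delegate to one loop that consumes an iterable of row dicts:

```python
def evaluate_rows(rows):
    """Shared loop over dicts with 'question', 'response', 'retrieved_contexts'
    and optional 'reference'; returns the average score per metric."""
    metrics = {
        "Faithfulness": LLMRAGFaithfulness,
        "Context Precision": LLMRAGContextPrecision,
        "Context Recall": LLMRAGContextRecall,
        "Context Relevancy": LLMRAGContextRelevancy,
        "Answer Relevancy": LLMRAGAnswerRelevancy,
    }
    totals = {name: 0.0 for name in metrics}
    count = 0
    for count, row in enumerate(rows, start=1):
        contexts = row.get("retrieved_contexts", [])
        if isinstance(contexts, str):
            contexts = [contexts]
        data = Data(
            data_id=f"row_{count}",
            prompt=row["question"],
            content=row["response"],
            context=contexts,
            reference=row.get("reference", ""),
        )
        for name, evaluator in metrics.items():
            totals[name] += evaluator.eval(data).score
    return {name: total / count for name, total in totals.items()} if count else {}


def evaluate_from_jsonl(jsonl_path):
    with open(jsonl_path, "r", encoding="utf-8") as f:
        averages = evaluate_rows(json.loads(line) for line in f if line.strip())
    for name, avg in averages.items():
        logger.info(f"{name} 平均值: {avg:.2f}/10")
```

The CSV variant could then build the same row dicts from csv.DictReader and call the same helper.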
```python
def evaluate_from_jsonl(jsonl_path):
    """从JSONL文件读取数据并进行RAG指标评测"""
    logger.info(f"\n从JSONL文件 {jsonl_path} 读取数据进行评测...")
    print(f"\n从JSONL文件 {jsonl_path} 读取数据进行评测...")
```
This script configures logging to output to both the console (StreamHandler) and a file. However, it then uses both logger.info() and print() for the same or similar messages throughout the file. This is redundant. It's better to rely solely on the logger for all informational output to maintain consistency and have a single point of control for verbosity.
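A minimal sketch of that single-logger setup with the standard library logging module (the logger name, log file name, and handler targets here are assumptions, not taken from the script):

```python
import logging

logger = logging.getLogger("rag_eval")
logger.setLevel(logging.INFO)

# Console output replaces the scattered print() calls.
logger.addHandler(logging.StreamHandler())
# File output keeps a persistent record of the run.
logger.addHandler(logging.FileHandler("rag_eval.log", encoding="utf-8"))

# Each message is emitted once and reaches both destinations.
logger.info("从JSONL文件 %s 读取数据进行评测...", "data.jsonl")
```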
```python
try:
    retrieved_contexts = json.loads(row['retrieved_contexts'])
except json.JSONDecodeError:
    # 如果不是JSON字符串,尝试按列表格式解析
    retrieved_contexts = [context.strip() for context in row['retrieved_contexts'].strip('[]').split(',')]
```
The fallback logic for parsing retrieved_contexts is fragile. Using split(',') will fail if any of the context strings themselves contain a comma. For a more robust solution, it would be better to enforce and document a single format for this field, such as a JSON-encoded string list, and avoid brittle fallback parsing.
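A stricter variant along those lines (a sketch only; the helper name and error wording are not from this PR) would accept nothing but a JSON-encoded list of strings and fail loudly otherwise:

```python
import json

def parse_contexts(raw: str) -> list[str]:
    """Require retrieved_contexts to be a JSON-encoded list of strings."""
    try:
        contexts = json.loads(raw)
    except json.JSONDecodeError as exc:
        raise ValueError(f"retrieved_contexts is not valid JSON: {raw!r}") from exc
    if not isinstance(contexts, list) or not all(isinstance(c, str) for c in contexts):
        raise ValueError("retrieved_contexts must be a JSON list of strings")
    return contexts
```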
```python
class ModelRes(BaseModel):
    eval_status: bool = False
    eval_details: EvalDetail = EvalDetail()
    score: Optional[float] = None
```
It looks like this score field doesn't need to be added?
```python
result.score = score

# 根据分数判断是否通过(默认阈值5,满分10分)
threshold = 5
```
It feels like this hard-coded value could be removed.
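If a pass/fail threshold is kept at all, a minimal sketch (whether dynamic_config.parameters is accessible at this point in the evaluator is an assumption based on the AnswerRelevancy configuration shown earlier in this PR) could read it from the evaluator's parameters and fall back to 5 only when nothing is configured:

```python
def resolve_threshold(dynamic_config, default: int = 5) -> int:
    """Pick the pass/fail threshold from the evaluator config when present."""
    parameters = getattr(dynamic_config, "parameters", None) or {}
    return parameters.get("threshold", default)

# Usage sketch inside the evaluator (names assumed from this PR):
#   threshold = resolve_threshold(cls.dynamic_config)
#   result.eval_status = result.score < threshold  # True marks a failing item
```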
No description provided.