"""
LLM-based evaluator for RAG Context Precision.

Uses an LLM to assess the precision and ranking quality of retrieved contexts.
"""

import json
from typing import List

from dingo.io import Data
from dingo.model import Model
from dingo.model.llm.base_openai import BaseOpenAI
from dingo.model.modelres import ModelRes
from dingo.model.prompt.prompt_rag_context_precision import PromptRAGContextPrecision
from dingo.model.response.response_class import ResponseScoreReason
from dingo.utils import log
from dingo.utils.exception import ConvertJsonError


@Model.llm_register("LLMRAGContextPrecision")
class LLMRAGContextPrecision(BaseOpenAI):
    """
    LLM evaluator for RAG context precision.

    Input requirements:
    - input_data.prompt or raw_data['question']: the user question
    - input_data.content or raw_data['answer']: the generated answer
    - input_data.context or raw_data['contexts']: the list of retrieved contexts
    """

    prompt = PromptRAGContextPrecision
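
    # Illustrative input example (sample values, not taken from the repository):
    #   input_data.prompt  = "What is the capital of France?"              # user question
    #   input_data.content = "The capital of France is Paris."             # generated answer
    #   input_data.context = ["Paris is the capital and largest city of France.",
    #                         "France is a country in Western Europe."]    # retrieved contexts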

    @classmethod
    def build_messages(cls, input_data: Data) -> List:
        """Build the LLM input messages."""
        # Extract the question and answer fields
        question = input_data.prompt or input_data.raw_data.get("question", "")
        answer = input_data.content or input_data.raw_data.get("answer", "")

        if not answer:
            raise ValueError("Context Precision evaluation requires an 'answer' field")

        # Normalize contexts into a list
        contexts = None
        if input_data.context:
            if isinstance(input_data.context, list):
                contexts = input_data.context
            else:
                contexts = [input_data.context]
        elif "contexts" in input_data.raw_data:
            raw_contexts = input_data.raw_data["contexts"]
            if isinstance(raw_contexts, list):
                contexts = raw_contexts
            else:
                contexts = [raw_contexts]

        if not contexts:
            raise ValueError("Context Precision evaluation requires a 'contexts' field")

        # Format the contexts as a numbered list
        contexts_formatted = "\n".join([f"{i + 1}. {ctx}" for i, ctx in enumerate(contexts)])
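        # e.g. contexts == ["ctx A", "ctx B"] yields "1. ctx A\n2. ctx B"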

        # Fill the prompt template with the question, answer and formatted contexts
        prompt_content = cls.prompt.content.format(question, answer, contexts_formatted)

        messages = [{"role": "user", "content": prompt_content}]

        return messages

    @classmethod
    def process_response(cls, response: str) -> ModelRes:
        """Process the LLM response."""
        log.info(f"RAG Context Precision response: {response}")

        # Strip markdown code fences (e.g. ```json ... ```) that some models wrap around JSON;
        # strip surrounding whitespace first so a trailing newline does not hide the closing fence
        response = response.strip()
        if response.startswith("```json"):
            response = response[7:]
        if response.startswith("```"):
            response = response[3:]
        if response.endswith("```"):
            response = response[:-3]

        try:
            response_json = json.loads(response.strip())
        except json.JSONDecodeError:
            raise ConvertJsonError(f"Convert to JSON format failed: {response}")

        # Parse the response into a score/reason model
        response_model = ResponseScoreReason(**response_json)
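        # Expected parsed shape (assumption, based on the ResponseScoreReason fields used below):
        # {"score": 8, "reason": "All retrieved contexts are relevant and ranked correctly."}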

        result = ModelRes()
        result.score = response_model.score

        # Decide pass/fail from the score (default threshold 5 on a 0-10 scale)
        threshold = 5
        if hasattr(cls, 'dynamic_config') and cls.dynamic_config.parameters:
            threshold = cls.dynamic_config.parameters.get('threshold', 5)
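        # Illustrative override: configuring dynamic_config.parameters = {"threshold": 7}
        # would require a score of at least 7 to pass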

        if response_model.score >= threshold:
            result.error_status = False
            result.type = "QUALITY_GOOD"
            result.name = "CONTEXT_PRECISION_PASS"
            result.reason = [f"Context precision evaluation passed (score: {response_model.score}/10)\n{response_model.reason}"]
        else:
            result.error_status = True
            result.type = cls.prompt.metric_type
            result.name = cls.prompt.__name__
            result.reason = [f"Context precision evaluation failed (score: {response_model.score}/10)\n{response_model.reason}"]

        return result
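
# Minimal usage sketch (assumes only the classmethods defined above; the actual LLM call is
# normally driven by BaseOpenAI, so this is illustrative, not an official example):
#   messages = LLMRAGContextPrecision.build_messages(data)   # data: dingo.io.Data with prompt/content/context
#   raw = ...                                                 # raw model output from an OpenAI-compatible client
#   result = LLMRAGContextPrecision.process_response(raw)     # -> ModelRes with score and pass/fail status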