|
42 | 42 | "from dataclasses import dataclass, field\n", |
43 | 43 | "from pydantic import BaseModel\n", |
44 | 44 | "import typing as t\n", |
| 45 | + "import json\n", |
| 46 | + "from tqdm import tqdm\n", |
| 47 | + "\n", |
| 48 | + "from ragas_annotator.prompt.base import Prompt\n", |
| 49 | + "from ragas_annotator.embedding.base import BaseEmbedding\n", |
45 | 50 | "from ragas_annotator.metric import MetricResult\n", |
46 | 51 | "from ragas_annotator.llm import RagasLLM\n", |
| 52 | + "from ragas_annotator.project.core import Project\n", |
| 53 | + "from ragas_annotator.model.notion_model import NotionModel\n", |
| 54 | + "from ragas_annotator.prompt.dynamic_few_shot import DynamicFewShotPrompt\n", |
| 55 | + "\n", |
47 | 56 | "\n", |
48 | 57 | "@dataclass\n", |
49 | 58 | "class Metric(ABC):\n", |
50 | 59 | " \"\"\"Base class for all metrics in the LLM evaluation library.\"\"\"\n", |
51 | 60 | " name: str\n", |
52 | | - " prompt: str\n", |
| 61 | + " prompt: str | Prompt\n", |
53 | 62 | " llm: RagasLLM\n", |
54 | 63 | " _response_models: t.Dict[bool, t.Type[BaseModel]] = field(\n", |
55 | 64 | " default_factory=dict, init=False, repr=False\n", |
56 | 65 | " )\n", |
57 | 66 | " \n", |
| 67 | + " def __post_init__(self):\n", |
| 68 | + " if isinstance(self.prompt,str):\n", |
| 69 | + " self.prompt = Prompt(self.prompt)\n", |
| 70 | + " \n", |
58 | 71 | " @abstractmethod\n", |
59 | 72 | " def _get_response_model(self, with_reasoning: bool) -> t.Type[BaseModel]:\n", |
60 | 73 | " \"\"\"Get the appropriate response model.\"\"\"\n", |
|
67 | 80 | " \n", |
68 | 81 | " def score(self, reasoning: bool = True, n: int = 1, **kwargs) -> t.Any:\n", |
69 | 82 | " responses = []\n", |
| 83 | + " traces = {}\n", |
| 84 | + " traces[\"input\"] = kwargs\n", |
70 | 85 | " prompt_input = self.prompt.format(**kwargs)\n", |
71 | 86 | " for _ in range(n):\n", |
72 | 87 | " response = self.llm.generate(prompt_input, response_model = self._get_response_model(reasoning)) \n", |
| 88 | + " traces['output'] = response.model_dump()\n", |
73 | 89 | " response = MetricResult(**response.model_dump())\n", |
74 | 90 | " responses.append(response)\n", |
75 | | - " return self._ensemble(responses)\n", |
| 91 | + " results = self._ensemble(responses)\n", |
| 92 | + " results.traces = traces\n", |
| 93 | + " return results\n", |
76 | 94 | "\n", |
77 | 95 | "\n", |
78 | 96 | " async def ascore(self, reasoning: bool = True, n: int = 1, **kwargs) -> MetricResult:\n", |
79 | 97 | " responses = [] # Added missing initialization\n", |
| 98 | + " traces = {}\n", |
| 99 | + " traces[\"input\"] = kwargs\n", |
80 | 100 | " prompt_input = self.prompt.format(**kwargs)\n", |
81 | 101 | " for _ in range(n):\n", |
82 | 102 | " response = await self.llm.agenerate(prompt_input, response_model = self._get_response_model(reasoning))\n", |
| 103 | + " traces['output'] = response.model_dump()\n", |
83 | 104 | " response = MetricResult(**response.model_dump()) # Fixed missing parentheses\n", |
84 | 105 | " responses.append(response)\n", |
85 | | - " return self._ensemble(responses)\n", |
| 106 | + " results = self._ensemble(responses)\n", |
| 107 | + " results.traces = traces\n", |
| 108 | + " return results\n", |
86 | 109 | " \n", |
87 | 110 | " def batch_score(self, inputs: t.List[t.Dict[str, t.Any]], reasoning: bool = True, n: int = 1) -> t.List[t.Any]:\n", |
88 | 111 | " return [self.score(reasoning, n, **input_dict) for input_dict in inputs]\n", |
|
94 | 117 | " async_tasks.append(self.ascore(reasoning=reasoning, n=n, **input_dict))\n", |
95 | 118 | " \n", |
96 | 119 | " # Run all tasks concurrently and return results\n", |
97 | | - " return await asyncio.gather(*async_tasks)" |
| 120 | + " return await asyncio.gather(*async_tasks)\n", |
| 121 | + " \n", |
| 122 | + " def train(self,project:Project, experiment_names: t.List[str], model:NotionModel, embedding_model: BaseEmbedding,method: t.Dict[str, t.Any]):\n", |
| 123 | + " \n", |
| 124 | + " assert isinstance(self.prompt, Prompt)\n", |
| 125 | + " self.prompt = DynamicFewShotPrompt.from_prompt(self.prompt,embedding_model)\n", |
| 126 | + " datasets = []\n", |
| 127 | + " for experiment_name in experiment_names:\n", |
| 128 | + " experiment_data = project.get_experiment(experiment_name,model)\n", |
| 129 | + " experiment_data.load()\n", |
| 130 | + " datasets.append(experiment_data)\n", |
| 131 | + " \n", |
| 132 | + " total_items = sum([len(dataset) for dataset in datasets])\n", |
| 133 | + " with tqdm(total=total_items, desc=\"Processing examples\") as pbar:\n", |
| 134 | + " for dataset in datasets:\n", |
| 135 | + " for row in dataset:\n", |
| 136 | + " if hasattr(row, f'{self.name}_traces'):\n", |
| 137 | + " traces = json.loads(getattr(row, f'{self.name}_traces'))\n", |
| 138 | + " if traces:\n", |
| 139 | + " self.prompt.add_example(traces['input'],traces['output'])\n", |
| 140 | + " pbar.update(1)\n", |
| 141 | + " \n", |
| 142 | + " \n", |
| 143 | + " \n", |
| 144 | + " \n", |
| 145 | + " \n", |
| 146 | + " \n", |
| 147 | + " " |
98 | 148 | ] |
99 | 149 | }, |
100 | 150 | { |
|
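For context, a minimal usage sketch of the revised Metric API (not part of the commit), assumed to run in the same notebook after the cell above. The subclass name, the two pydantic response models, and the _ensemble implementation are illustrative assumptions (_ensemble is called by score/ascore but not shown in this hunk), and my_llm, project, MyRowModel, and my_embedder are placeholders for objects configured elsewhere.

    import typing as t
    from dataclasses import dataclass
    from pydantic import BaseModel


    class DiscreteResponse(BaseModel):  # hypothetical response model
        result: int


    class DiscreteResponseWithReason(DiscreteResponse):  # hypothetical reasoning variant
        reason: str


    @dataclass
    class HelpfulnessMetric(Metric):  # Metric as defined in the cell above
        def _get_response_model(self, with_reasoning: bool) -> t.Type[BaseModel]:
            return DiscreteResponseWithReason if with_reasoning else DiscreteResponse

        def _ensemble(self, results: t.List[MetricResult]) -> MetricResult:
            # assumed hook: score()/ascore() call self._ensemble(); take the first sample
            return results[0]


    metric = HelpfulnessMetric(
        name="helpfulness",
        prompt="Rate the helpfulness of the response from 0-5.\nresponse: {response}",
        llm=my_llm,  # placeholder: an already-configured RagasLLM instance
    )

    result = metric.score(response="You can reset your password under Settings.")
    print(result)         # the ensembled MetricResult
    print(result.traces)  # {"input": {...}, "output": {...}} attached by score()

    # Once annotated experiments exist in a Project, the prompt can be upgraded to a
    # dynamic few-shot prompt seeded from the stored traces (placeholders again):
    # metric.train(project, ["experiment-1"], model=MyRowModel,
    #              embedding_model=my_embedder, method={})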