|
40 | 40 | "from abc import ABC, abstractmethod\n", |
41 | 41 | "import asyncio\n", |
42 | 42 | "from dataclasses import dataclass, field\n", |
| 43 | + "import datasets\n", |
43 | 44 | "from pydantic import BaseModel\n", |
44 | 45 | "import typing as t\n", |
| 46 | + "import json\n", |
| 47 | + "from tqdm import tqdm\n", |
| 48 | + "\n", |
| 49 | + "from ragas_annotator.prompt.base import Prompt\n", |
| 50 | + "from ragas_annotator.embedding.base import RagasEmbedding\n", |
45 | 51 | "from ragas_annotator.metric import MetricResult\n", |
46 | 52 | "from ragas_annotator.llm import RagasLLM\n", |
| 53 | + "from ragas_annotator.project.core import Project\n", |
| 54 | + "from ragas_annotator.model.notion_model import NotionModel\n", |
| 55 | + "from ragas_annotator.prompt.dynamic_few_shot import DynamicFewShotPrompt\n", |
| 56 | + "\n", |
47 | 57 | "\n", |
48 | 58 | "@dataclass\n", |
49 | 59 | "class Metric(ABC):\n", |
50 | 60 | " \"\"\"Base class for all metrics in the LLM evaluation library.\"\"\"\n", |
51 | 61 | " name: str\n", |
52 | | - " prompt: str\n", |
| 62 | + " prompt: str | Prompt\n", |
53 | 63 | " llm: RagasLLM\n", |
54 | 64 | " _response_models: t.Dict[bool, t.Type[BaseModel]] = field(\n", |
55 | 65 | " default_factory=dict, init=False, repr=False\n", |
56 | 66 | " )\n", |
57 | 67 | " \n", |
| 68 | + " def __post_init__(self):\n", |
| 69 | + "        if isinstance(self.prompt, str):\n", |
| 70 | + " self.prompt = Prompt(self.prompt)\n", |
| 71 | + " \n", |
58 | 72 | " @abstractmethod\n", |
59 | 73 | " def _get_response_model(self, with_reasoning: bool) -> t.Type[BaseModel]:\n", |
60 | 74 | " \"\"\"Get the appropriate response model.\"\"\"\n", |
|
67 | 81 | " \n", |
68 | 82 | " def score(self, reasoning: bool = True, n: int = 1, **kwargs) -> t.Any:\n", |
69 | 83 | " responses = []\n", |
| 84 | + " traces = {}\n", |
| 85 | + " traces[\"input\"] = kwargs\n", |
70 | 86 | " prompt_input = self.prompt.format(**kwargs)\n", |
71 | 87 | " for _ in range(n):\n", |
72 | 88 | " response = self.llm.generate(prompt_input, response_model = self._get_response_model(reasoning)) \n", |
| 89 | + " traces['output'] = response.model_dump()\n", |
73 | 90 | " response = MetricResult(**response.model_dump())\n", |
74 | 91 | " responses.append(response)\n", |
75 | | - " return self._ensemble(responses)\n", |
| 92 | + " results = self._ensemble(responses)\n", |
| 93 | + " results.traces = traces\n", |
| 94 | + " return results\n", |
76 | 95 | "\n", |
77 | 96 | "\n", |
78 | 97 | " async def ascore(self, reasoning: bool = True, n: int = 1, **kwargs) -> MetricResult:\n", |
79 | 98 | " responses = [] # Added missing initialization\n", |
| 99 | + " traces = {}\n", |
| 100 | + " traces[\"input\"] = kwargs\n", |
80 | 101 | " prompt_input = self.prompt.format(**kwargs)\n", |
81 | 102 | " for _ in range(n):\n", |
82 | 103 | " response = await self.llm.agenerate(prompt_input, response_model = self._get_response_model(reasoning))\n", |
| 104 | + " traces['output'] = response.model_dump()\n", |
83 | 105 | " response = MetricResult(**response.model_dump()) # Fixed missing parentheses\n", |
84 | 106 | " responses.append(response)\n", |
85 | | - " return self._ensemble(responses)\n", |
| 107 | + " results = self._ensemble(responses)\n", |
| 108 | + " results.traces = traces\n", |
| 109 | + " return results\n", |
86 | 110 | " \n", |
87 | 111 | " def batch_score(self, inputs: t.List[t.Dict[str, t.Any]], reasoning: bool = True, n: int = 1) -> t.List[t.Any]:\n", |
88 | 112 | " return [self.score(reasoning, n, **input_dict) for input_dict in inputs]\n", |
|
95 | 119 | " \n", |
96 | 120 | " # Run all tasks concurrently and return results\n", |
97 | 121 | " return await asyncio.gather(*async_tasks)\n", |
98 | | - " \n", |
| 122 | + " \n", |
| 123 | + "    def train(self, project: Project, experiment_names: t.List[str], model: NotionModel, embedding_model: RagasEmbedding, method: t.Dict[str, t.Any]):\n", |
| 124 | + " \n", |
| 125 | + " assert isinstance(self.prompt, Prompt)\n", |
| 126 | + "        self.prompt = DynamicFewShotPrompt.from_prompt(self.prompt, embedding_model)\n", |
| 127 | + " datasets = []\n", |
| 128 | + " for experiment_name in experiment_names:\n", |
| 129 | + "            experiment_data = project.get_experiment(experiment_name, model)\n", |
| 130 | + " experiment_data.load()\n", |
| 131 | + " datasets.append(experiment_data)\n", |
| 132 | + " \n", |
| 133 | + " total_items = sum([len(dataset) for dataset in datasets])\n", |
| 134 | + " with tqdm(total=total_items, desc=\"Processing examples\") as pbar:\n", |
| 135 | + " for dataset in datasets:\n", |
| 136 | + " for row in dataset:\n", |
| 137 | + " if hasattr(row, f'{self.name}_traces'):\n", |
| 138 | + " traces = json.loads(getattr(row, f'{self.name}_traces'))\n", |
| 139 | + " if traces:\n", |
| 140 | + "                            self.prompt.add_example(traces['input'], traces['output'])\n", |
| 141 | + " pbar.update(1)\n", |
| 142 | + " \n", |
| 143 | + " \n", |
| 144 | + " \n", |
| 145 | + " \n", |
| 146 | + " \n", |
99 | 147 | " \n", |
100 | | - " " |
| 148 | + " " |
101 | 149 | ] |
102 | 150 | }, |
103 | 151 | { |
|