|
42 | 42 | "from dataclasses import dataclass, field\n", |
43 | 43 | "from pydantic import BaseModel\n", |
44 | 44 | "import typing as t\n", |
45 | | - "import json\n", |
46 | 45 | "from tqdm import tqdm\n", |
| 46 | + "import string\n", |
| 47 | + "\n", |
47 | 48 | "\n", |
48 | 49 | "from ragas_annotator.prompt.base import Prompt\n", |
49 | 50 | "from ragas_annotator.embedding.base import BaseEmbedding\n", |
|
76 | 77 | " @abstractmethod\n", |
77 | 78 | " def _ensemble(self, results: t.List[MetricResult]) -> MetricResult:\n", |
78 | 79 | " pass\n", |
79 | | - " \n", |
| 80 | + " \n", |
| 81 | + " def get_variables(self) -> t.List[str]:\n", |
| 82 | + " if isinstance(self.prompt, Prompt):\n", |
| 83 | + " fstr = self.prompt.instruction\n", |
| 84 | + " else:\n", |
| 85 | + " fstr = self.prompt\n", |
| 86 | + "        variables = [field_name for _, field_name, _, _ in string.Formatter().parse(fstr) if field_name]\n", |
| 87 | + "        return variables\n", |
80 | 88 | " \n", |
81 | 89 | " def score(self, reasoning: bool = True, n: int = 1, **kwargs) -> t.Any:\n", |
82 | 90 | " responses = []\n", |
|
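The new `get_variables` helper relies on the standard library's `string.Formatter().parse`, which yields `(literal_text, field_name, format_spec, conversion)` tuples for a format-style template. A minimal standalone sketch of the same extraction, using a bare stand-in for the real `Prompt` class (only the `.instruction` attribute the method actually reads is assumed here):

```python
import string
import typing as t


class Prompt:
    """Stand-in for ragas_annotator.prompt.base.Prompt; only .instruction is assumed."""
    def __init__(self, instruction: str):
        self.instruction = instruction


def get_variables(prompt: t.Union[Prompt, str]) -> t.List[str]:
    # Unwrap a Prompt object to its template string, as in the method above.
    fstr = prompt.instruction if isinstance(prompt, Prompt) else prompt
    # parse() emits field_name=None for trailing literal text, so filter those out.
    return [field for _, field, _, _ in string.Formatter().parse(fstr) if field]


print(get_variables("What is the result of {input}?"))           # ['input']
print(get_variables(Prompt("Rate {response} against {query}")))  # ['response', 'query']
```
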
130 | 138 | " datasets.append(experiment_data)\n", |
131 | 139 | " \n", |
132 | 140 | " total_items = sum([len(dataset) for dataset in datasets])\n", |
| 141 | + " input_vars = self.get_variables()\n", |
| 142 | + " output_vars = [self.name, f'{self.name}_reason']\n", |
133 | 143 | " with tqdm(total=total_items, desc=\"Processing examples\") as pbar:\n", |
134 | 144 | " for dataset in datasets:\n", |
135 | 145 | " for row in dataset:\n", |
136 | | - " if hasattr(row, f'{self.name}_traces'):\n", |
137 | | - " traces = json.loads(getattr(row, f'{self.name}_traces'))\n", |
138 | | - " if traces:\n", |
139 | | - " self.prompt.add_example(traces['input'],traces['output'])\n", |
| 146 | + " inputs = {var: getattr(row, var) for var in input_vars if hasattr(row, var)}\n", |
| 147 | + " output = {var: getattr(row, var) for var in output_vars if hasattr(row, var)}\n", |
| 148 | + " if output:\n", |
| 149 | + "                            self.prompt.add_example(inputs, output)\n", |
140 | 150 | " pbar.update(1)\n", |
141 | 151 | " \n", |
142 | 152 | " \n", |
|
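With this change, few-shot examples are rebuilt from a row's own columns rather than a serialized `{name}_traces` JSON blob: inputs are whichever prompt variables exist on the row, and outputs are the `{name}` and `{name}_reason` columns. A rough sketch of the per-row mapping under those assumptions, with a hypothetical dataclass standing in for a dataset row:

```python
import typing as t
from dataclasses import dataclass


@dataclass
class Row:
    """Hypothetical dataset row for a metric named 'example'."""
    input: str
    example: str
    example_reason: str


name = "example"                        # metric name, as in the demo cell
input_vars = ["input"]                  # what get_variables() would return
output_vars = [name, f"{name}_reason"]

row = Row(input="2 + 2", example="4", example_reason="Simple arithmetic.")

# Same dict comprehensions as the loop body above; the hasattr guards
# skip variables that a given row does not carry.
inputs = {var: getattr(row, var) for var in input_vars if hasattr(row, var)}
output = {var: getattr(row, var) for var in output_vars if hasattr(row, var)}

print(inputs)  # {'input': '2 + 2'}
print(output)  # {'example': '4', 'example_reason': 'Simple arithmetic.'}
```
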
160 | 170 | "execution_count": null, |
161 | 171 | "id": "fcf208fa", |
162 | 172 | "metadata": {}, |
163 | | - "outputs": [], |
| 173 | + "outputs": [ |
| 174 | + { |
| 175 | + "data": { |
| 176 | + "text/plain": [ |
| 177 | + "100" |
| 178 | + ] |
| 179 | + }, |
| 180 | + "execution_count": null, |
| 181 | + "metadata": {}, |
| 182 | + "output_type": "execute_result" |
| 183 | + } |
| 184 | + ], |
164 | 185 | "source": [ |
165 | 186 | "#| eval: false\n", |
166 | 187 | "\n", |
|
189 | 210 | "my_metric = CustomMetric(name=\"example\", prompt=\"What is the result of {input}?\", llm=llm)\n", |
190 | 211 | "my_metric.score(input=\"test\")" |
191 | 212 | ] |
| 213 | + }, |
| 214 | + { |
| 215 | + "cell_type": "code", |
| 216 | + "execution_count": null, |
| 217 | + "metadata": {}, |
| 218 | + "outputs": [], |
| 219 | + "source": [] |
192 | 220 | } |
193 | 221 | ], |
194 | 222 | "metadata": { |
|
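Tying this back to the demo cell: the template passed to `CustomMetric` carries a single `{input}` variable, so training would read an `input` column from each row and pair it with the `example` and `example_reason` output columns. A quick standard-library-only check of that wiring:

```python
import string

name = "example"
prompt = "What is the result of {input}?"

# Mirrors get_variables() on the raw template string.
input_vars = [f for _, f, _, _ in string.Formatter().parse(prompt) if f]
output_vars = [name, f"{name}_reason"]

assert input_vars == ["input"]
assert output_vars == ["example", "example_reason"]
```
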