Skip to content

Commit 0a82a1e

Browse files
committed
patch for fixing few shot optim
1 parent b9fa013 commit 0a82a1e

File tree

4 files changed

+54
-14
lines changed

4 files changed

+54
-14
lines changed

nbs/metric/base.ipynb

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,9 @@
4242
"from dataclasses import dataclass, field\n",
4343
"from pydantic import BaseModel\n",
4444
"import typing as t\n",
45-
"import json\n",
4645
"from tqdm import tqdm\n",
46+
"import string\n",
47+
"\n",
4748
"\n",
4849
"from ragas_annotator.prompt.base import Prompt\n",
4950
"from ragas_annotator.embedding.base import BaseEmbedding\n",
@@ -76,7 +77,14 @@
7677
" @abstractmethod\n",
7778
" def _ensemble(self, results: t.List[MetricResult]) -> MetricResult:\n",
7879
" pass\n",
79-
" \n",
80+
" \n",
81+
" def get_variables(self) -> t.List[str]:\n",
82+
" if isinstance(self.prompt, Prompt):\n",
83+
" fstr = self.prompt.instruction\n",
84+
" else:\n",
85+
" fstr = self.prompt\n",
86+
" vars = [field_name for _, field_name, _, _ in string.Formatter().parse(fstr) if field_name]\n",
87+
" return vars\n",
8088
" \n",
8189
" def score(self, reasoning: bool = True, n: int = 1, **kwargs) -> t.Any:\n",
8290
" responses = []\n",
@@ -130,13 +138,15 @@
130138
" datasets.append(experiment_data)\n",
131139
" \n",
132140
" total_items = sum([len(dataset) for dataset in datasets])\n",
141+
" input_vars = self.get_variables()\n",
142+
" output_vars = [self.name, f'{self.name}_reason']\n",
133143
" with tqdm(total=total_items, desc=\"Processing examples\") as pbar:\n",
134144
" for dataset in datasets:\n",
135145
" for row in dataset:\n",
136-
" if hasattr(row, f'{self.name}_traces'):\n",
137-
" traces = json.loads(getattr(row, f'{self.name}_traces'))\n",
138-
" if traces:\n",
139-
" self.prompt.add_example(traces['input'],traces['output'])\n",
146+
" inputs = {var: getattr(row, var) for var in input_vars if hasattr(row, var)}\n",
147+
" output = {var: getattr(row, var) for var in output_vars if hasattr(row, var)}\n",
148+
" if output:\n",
149+
" self.prompt.add_example(inputs,output)\n",
140150
" pbar.update(1)\n",
141151
" \n",
142152
" \n",
@@ -160,7 +170,18 @@
160170
"execution_count": null,
161171
"id": "fcf208fa",
162172
"metadata": {},
163-
"outputs": [],
173+
"outputs": [
174+
{
175+
"data": {
176+
"text/plain": [
177+
"100"
178+
]
179+
},
180+
"execution_count": null,
181+
"metadata": {},
182+
"output_type": "execute_result"
183+
}
184+
],
164185
"source": [
165186
"#| eval: false\n",
166187
"\n",
@@ -189,6 +210,13 @@
189210
"my_metric = CustomMetric(name=\"example\", prompt=\"What is the result of {input}?\", llm=llm)\n",
190211
"my_metric.score(input=\"test\")"
191212
]
213+
},
214+
{
215+
"cell_type": "code",
216+
"execution_count": null,
217+
"metadata": {},
218+
"outputs": [],
219+
"source": []
192220
}
193221
],
194222
"metadata": {

ragas_annotator/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.0.1"
1+
__version__ = "0.0.2"
22
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/init_module.ipynb.
33

44
# %% auto 0

ragas_annotator/_modidx.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@
183183
'ragas_annotator/metric/base.py'),
184184
'ragas_annotator.metric.base.Metric.batch_score': ( 'metric/base.html#metric.batch_score',
185185
'ragas_annotator/metric/base.py'),
186+
'ragas_annotator.metric.base.Metric.get_variables': ( 'metric/base.html#metric.get_variables',
187+
'ragas_annotator/metric/base.py'),
186188
'ragas_annotator.metric.base.Metric.score': ( 'metric/base.html#metric.score',
187189
'ragas_annotator/metric/base.py'),
188190
'ragas_annotator.metric.base.Metric.train': ( 'metric/base.html#metric.train',

ragas_annotator/metric/base.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
from dataclasses import dataclass, field
1212
from pydantic import BaseModel
1313
import typing as t
14-
import json
1514
from tqdm import tqdm
15+
import string
16+
1617

1718
from ..prompt.base import Prompt
1819
from ..embedding.base import BaseEmbedding
@@ -45,7 +46,14 @@ def _get_response_model(self, with_reasoning: bool) -> t.Type[BaseModel]:
4546
@abstractmethod
4647
def _ensemble(self, results: t.List[MetricResult]) -> MetricResult:
4748
pass
48-
49+
50+
def get_variables(self) -> t.List[str]:
51+
if isinstance(self.prompt, Prompt):
52+
fstr = self.prompt.instruction
53+
else:
54+
fstr = self.prompt
55+
vars = [field_name for _, field_name, _, _ in string.Formatter().parse(fstr) if field_name]
56+
return vars
4957

5058
def score(self, reasoning: bool = True, n: int = 1, **kwargs) -> t.Any:
5159
responses = []
@@ -99,13 +107,15 @@ def train(self,project:Project, experiment_names: t.List[str], model:NotionModel
99107
datasets.append(experiment_data)
100108

101109
total_items = sum([len(dataset) for dataset in datasets])
110+
input_vars = self.get_variables()
111+
output_vars = [self.name, f'{self.name}_reason']
102112
with tqdm(total=total_items, desc="Processing examples") as pbar:
103113
for dataset in datasets:
104114
for row in dataset:
105-
if hasattr(row, f'{self.name}_traces'):
106-
traces = json.loads(getattr(row, f'{self.name}_traces'))
107-
if traces:
108-
self.prompt.add_example(traces['input'],traces['output'])
115+
inputs = {var: getattr(row, var) for var in input_vars if hasattr(row, var)}
116+
output = {var: getattr(row, var) for var in output_vars if hasattr(row, var)}
117+
if output:
118+
self.prompt.add_example(inputs,output)
109119
pbar.update(1)
110120

111121

0 commit comments

Comments (0)