Commit 31f9dda

changes for few shot learning
1 parent: 2459574 · commit: 31f9dda

File tree

12 files changed: +214 -70 lines


nbs/metric/base.ipynb

Lines changed: 53 additions & 5 deletions
@@ -40,21 +40,35 @@
"from abc import ABC, abstractmethod\n",
"import asyncio\n",
"from dataclasses import dataclass, field\n",
+"import datasets\n",
"from pydantic import BaseModel\n",
"import typing as t\n",
+"import json\n",
+"from tqdm import tqdm\n",
+"\n",
+"from ragas_annotator.prompt.base import Prompt\n",
+"from ragas_annotator.embedding.base import RagasEmbedding\n",
"from ragas_annotator.metric import MetricResult\n",
"from ragas_annotator.llm import RagasLLM\n",
+"from ragas_annotator.project.core import Project\n",
+"from ragas_annotator.model.notion_model import NotionModel\n",
+"from ragas_annotator.prompt.dynamic_few_shot import DynamicFewShotPrompt\n",
+"\n",
"\n",
"@dataclass\n",
"class Metric(ABC):\n",
"    \"\"\"Base class for all metrics in the LLM evaluation library.\"\"\"\n",
"    name: str\n",
-"    prompt: str\n",
+"    prompt: str | Prompt\n",
"    llm: RagasLLM\n",
"    _response_models: t.Dict[bool, t.Type[BaseModel]] = field(\n",
"        default_factory=dict, init=False, repr=False\n",
"    )\n",
"    \n",
+"    def __post_init__(self):\n",
+"        if isinstance(self.prompt,str):\n",
+"            self.prompt = Prompt(self.prompt)\n",
+"    \n",
"    @abstractmethod\n",
"    def _get_response_model(self, with_reasoning: bool) -> t.Type[BaseModel]:\n",
"        \"\"\"Get the appropriate response model.\"\"\"\n",
@@ -67,22 +81,32 @@
"    \n",
"    def score(self, reasoning: bool = True, n: int = 1, **kwargs) -> t.Any:\n",
"        responses = []\n",
+"        traces = {}\n",
+"        traces[\"input\"] = kwargs\n",
"        prompt_input = self.prompt.format(**kwargs)\n",
"        for _ in range(n):\n",
"            response = self.llm.generate(prompt_input, response_model = self._get_response_model(reasoning)) \n",
+"            traces['output'] = response.model_dump()\n",
"            response = MetricResult(**response.model_dump())\n",
"            responses.append(response)\n",
-"        return self._ensemble(responses)\n",
+"        results = self._ensemble(responses)\n",
+"        results.traces = traces\n",
+"        return results\n",
"\n",
"\n",
"    async def ascore(self, reasoning: bool = True, n: int = 1, **kwargs) -> MetricResult:\n",
"        responses = []  # Added missing initialization\n",
+"        traces = {}\n",
+"        traces[\"input\"] = kwargs\n",
"        prompt_input = self.prompt.format(**kwargs)\n",
"        for _ in range(n):\n",
"            response = await self.llm.agenerate(prompt_input, response_model = self._get_response_model(reasoning))\n",
+"            traces['output'] = response.model_dump()\n",
"            response = MetricResult(**response.model_dump())  # Fixed missing parentheses\n",
"            responses.append(response)\n",
-"        return self._ensemble(responses)\n",
+"        results = self._ensemble(responses)\n",
+"        results.traces = traces\n",
+"        return results\n",
"    \n",
"    def batch_score(self, inputs: t.List[t.Dict[str, t.Any]], reasoning: bool = True, n: int = 1) -> t.List[t.Any]:\n",
"        return [self.score(reasoning, n, **input_dict) for input_dict in inputs]\n",
@@ -95,9 +119,33 @@
"    \n",
"        # Run all tasks concurrently and return results\n",
"        return await asyncio.gather(*async_tasks)\n",
-"    \n",
+"    \n",
+"    def train(self,project:Project, experiment_names: t.List[str], model:NotionModel, embedding_model: RagasEmbedding,method: t.Dict[str, t.Any]):\n",
+"        \n",
+"        assert isinstance(self.prompt, Prompt)\n",
+"        self.prompt = DynamicFewShotPrompt.from_prompt(self.prompt,embedding_model)\n",
+"        datasets = []\n",
+"        for experiment_name in experiment_names:\n",
+"            experiment_data = project.get_experiment(experiment_name,model)\n",
+"            experiment_data.load()\n",
+"            datasets.append(experiment_data)\n",
+"        \n",
+"        total_items = sum([len(dataset) for dataset in datasets])\n",
+"        with tqdm(total=total_items, desc=\"Processing examples\") as pbar:\n",
+"            for dataset in datasets:\n",
+"                for row in dataset:\n",
+"                    if hasattr(row, f'{self.name}_traces'):\n",
+"                        traces = json.loads(getattr(row, f'{self.name}_traces'))\n",
+"                        if traces:\n",
+"                            self.prompt.add_example(traces['input'],traces['output'])\n",
+"                    pbar.update(1)\n",
+"        \n",
+"        \n",
+"        \n",
+"        \n",
+"        \n",
"    \n",
-"    "
+"        "
]
},
{
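For orientation, a minimal sketch of the two behaviours this diff adds to Metric: a plain string prompt is wrapped into a Prompt in __post_init__, and score()/ascore() attach an {'input', 'output'} traces dict to the returned result. The stub LLM, the EchoMetric subclass, and the ragas_annotator.metric.base import path are illustrative assumptions, not part of the commit.

from dataclasses import dataclass
import typing as t
from pydantic import BaseModel

from ragas_annotator.metric.base import Metric  # assumed export path for the notebook above

class StubResponse(BaseModel):
    result: str
    reason: str

class StubLLM:
    """Stands in for RagasLLM; only the generate() call used by score() is implemented."""
    def generate(self, prompt_input: str, response_model: t.Type[BaseModel]):
        return response_model(result="high", reason="stubbed reason")

@dataclass
class EchoMetric(Metric):
    def _get_response_model(self, with_reasoning: bool) -> t.Type[BaseModel]:
        return StubResponse

    def _ensemble(self, results):
        return results[0]

m = EchoMetric(name="echo", prompt="Evaluate: {response}", llm=StubLLM())  # str prompt becomes a Prompt
result = m.score(response="hello", n=1)
print(result)         # 'high'
print(result.traces)  # {'input': {'response': 'hello'}, 'output': {'result': 'high', 'reason': 'stubbed reason'}}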

nbs/metric/decorator.ipynb

Lines changed: 4 additions & 13 deletions
@@ -66,15 +66,6 @@
"    \n",
"    @dataclass\n",
"    class CustomMetric(metric_class):\n",
-"        def _extract_result(self, result, reasoning: bool):\n",
-"            \"\"\"Extract score and reason from the result.\"\"\"\n",
-"            if isinstance(result, tuple) and len(result) == 2:\n",
-"                score, reason = result\n",
-"            else:\n",
-"                score, reason = result, None\n",
-"            \n",
-"            # Use \"result\" instead of \"score\" for the new MetricResult implementation\n",
-"            return MetricResult(result=score, reason=reason if reasoning else None)\n",
"        \n",
"        def _run_sync_in_async(self, func, *args, **kwargs):\n",
"            \"\"\"Run a synchronous function in an async context.\"\"\"\n",
@@ -101,7 +92,7 @@
"                    # Sync function implementation\n",
"                    result = func(self.llm, self.prompt, **kwargs)\n",
"                \n",
-"                return self._extract_result(result, reasoning)\n",
+"                return result\n",
"            except Exception as e:\n",
"                # Handle errors gracefully\n",
"                error_msg = f\"Error executing metric {self.name}: {str(e)}\"\n",
@@ -120,7 +111,7 @@
"                else:\n",
"                    # For sync functions, run normally\n",
"                    result = self._run_sync_in_async(func, self.llm, self.prompt, **kwargs)\n",
-"                return self._extract_result(result, reasoning)\n",
+"                return result\n",
"        \n",
"    # Create the metric instance with all parameters\n",
"    metric_instance = CustomMetric(\n",
@@ -168,7 +159,7 @@
"#| eval: false\n",
"\n",
"\n",
-"from ragas_annotator.metric import DiscreteMetric\n",
+"from ragas_annotator.metric import DiscreteMetric, MetricResult\n",
"from pydantic import BaseModel\n",
"\n",
"from ragas_annotator.llm import ragas_llm\n",
@@ -193,7 +184,7 @@
"        score = 'low'\n",
"    else:\n",
"        score = 'high'\n",
-"    return score,\"reason\"\n",
+"    return MetricResult(result=score, reason=response.reason)\n",
"\n",
"result = my_metric.score(response='my response') # result\n",
"print(result)\n",

nbs/metric/discrete.ipynb

Lines changed: 20 additions & 16 deletions
@@ -21,16 +21,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
-"outputs": [
-{
-"name": "stderr",
-"output_type": "stream",
-"text": [
-"/opt/homebrew/Caskroom/miniforge/base/envs/random/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
-"  from .autonotebook import tqdm as notebook_tqdm\n"
-]
-}
-],
+"outputs": [],
"source": [
"#| export\n",
"import typing as t\n",
@@ -99,8 +90,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
-"low\n",
-"The response does not provide any specific information or context that can help evaluate its helpfulness.\n"
+"med\n",
+"The response 'this is my response' is somewhat helpful because it acknowledges receipt or presence of a response but lacks specific information or relevance that would make it more beneficial or informative.\n"
]
}
],
@@ -143,13 +134,23 @@
"name": "stdout",
"output_type": "stream",
"text": [
-"high\n",
-"reason\n"
+"None\n",
+"Error executing metric new_metric: MetricResult.__init__() got an unexpected keyword argument 'traces'\n"
+]
+},
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/var/folders/ww/sk5dkfhn673234cmy5w7008r0000gn/T/ipykernel_31199/2702091914.py:15: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/\n",
+"  traces['output'] = response.dict()\n"
]
}
],
"source": [
"#| eval: false\n",
+"from ragas_annotator.metric.result import MetricResult\n",
+"\n",
"@discrete_metric(llm=llm,\n",
"    prompt=\"Evaluate if given answer is helpful\\n\\n{response}\",\n",
"    name='new_metric',values=[\"low\",\"med\",\"high\"])\n",
@@ -158,14 +159,17 @@
"    class response_model(BaseModel):\n",
"        output: t.List[bool]\n",
"        reason: str\n",
-"    \n",
+"    traces = {}\n",
+"    traces['input'] = kwargs\n",
"    response = llm.generate(prompt.format(**kwargs),response_model=response_model)\n",
+"    traces['output'] = response.dict()\n",
"    total = sum(response.output)\n",
"    if total < 1:\n",
"        score = 'low'\n",
"    else:\n",
"        score = 'high'\n",
-"    return score,\"reason\"\n",
+"    \n",
+"    return MetricResult(result=score,reason=response.reason,traces=traces)\n",
"\n",
"result = my_metric.score(response='my response') # result\n",
"print(result)\n",

nbs/metric/numeric.ipynb

Lines changed: 5 additions & 1 deletion
@@ -147,6 +147,7 @@
"source": [
"\n",
"#| eval: false\n",
+"from ragas_annotator.metric import MetricResult\n",
"\n",
"@numeric_metric(llm=llm,\n",
"    prompt=\"Evaluate if given answer is helpful\\n\\n{response}\",\n",
@@ -157,13 +158,16 @@
"        output: int\n",
"        reason: str\n",
"    \n",
+"    traces = {}\n",
+"    traces['input'] = kwargs\n",
"    response = llm.generate(prompt.format(**kwargs),response_model=response_model)\n",
+"    traces['output'] = response.dict()\n",
"    total = response.output\n",
"    if total < 1:\n",
"        score = 0\n",
"    else:\n",
"        score = 10\n",
-"    return score,\"reason\"\n",
+"    return MetricResult(result=score,reason=response.reason,traces=traces)\n",
"\n",
"result = my_metric.score(response='my response') # result\n",
"result # 10\n",

nbs/metric/ranking.ipynb

Lines changed: 2 additions & 1 deletion
@@ -185,6 +185,7 @@
"source": [
"#| eval: false\n",
"\n",
+"from ragas_annotator.metric import MetricResult\n",
"\n",
"@ranking_metric(\n",
"    llm=llm, # Your language model instance\n",
@@ -197,7 +198,7 @@
"    # For example, process the prompt (formatted with candidates) and produce a ranking.\n",
"    ranking = [1, 0, 2]  # Dummy ranking: second candidate is best, then first, then third.\n",
"    reason = \"Ranked based on response clarity and detail.\"\n",
-"    return ranking, reason\n",
+"    return MetricResult(result=ranking, reason=reason)\n",
"\n",
"# Using the decorator-based ranking metric:\n",
"result = my_ranking_metric.score(candidates=[\n",

nbs/metric/result.ipynb

Lines changed: 6 additions & 1 deletion
@@ -46,9 +46,14 @@
"    - RankingMetrics (list results)\n",
"    \"\"\"\n",
"    \n",
-"    def __init__(self, result: t.Any, reason: t.Optional[str] = None):\n",
+"    def __init__(self, result: t.Any, reason: t.Optional[str] = None, traces: t.Optional[t.Dict[str, t.Any]] = None):\n",
+"        if traces is not None:\n",
+"            invalid_keys = [key for key in traces.keys() if key not in {\"input\", \"output\"}]\n",
+"            if invalid_keys:\n",
+"                raise ValueError(f\"Invalid keys in traces: {invalid_keys}. Allowed keys are 'input' and 'output'.\")\n",
"        self._result = result\n",
"        self.reason = reason\n",
+"        self.traces = traces\n",
"        \n",
"    def __repr__(self):\n",
"        return repr(self._result)\n",

nbs/prompt/dynamic_few_shot.ipynb

Lines changed: 56 additions & 14 deletions
@@ -194,10 +194,11 @@
"        TypeError\n",
"            If inputs or output is not a dictionary\n",
"        \"\"\"\n",
-"        \n",
-"        self.examples.append((inputs, output))\n",
+"        if (inputs, output) not in self.examples:\n",
+"            self.examples.append((inputs, output))\n",
+"        \n",
"        # Add to example store\n",
-"        if self.example_store:\n",
+"        if isinstance(self.example_store, ExampleStore) and (inputs, output) not in self.example_store._examples:\n",
"            self.example_store.add_example(inputs, output)\n",
"    \n",
"    @classmethod\n",
@@ -230,7 +231,24 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Evaluate if given answer Regularly updating your software reduces the risk of vulnerabilities. is same as expected answer Keeping software up to date helps patch known security flaws and prevents exploits.\n",
+"\n",
+"Examples:\n",
+"\n",
+"Example 1:\n",
+"Input:\n",
+"response: Using two-factor authentication greatly enhances account security.\n",
+"expected_answer: Two-factor authentication adds a layer of protection by requiring a second form of identity verification.\n",
+"Output:\n",
+"score: fail\n"
+]
+}
+],
"source": [
"#| eval: false\n",
"from ragas_annotator.prompt import Prompt\n",
@@ -255,22 +273,46 @@
"prompt = DynamicFewShotPrompt.from_prompt(\n",
"    prompt,\n",
"    embedding_model=embedding,\n",
-"    num_examples=3\n",
+"    num_examples=1\n",
+")\n",
+"\n",
+"prompt.add_example(\n",
+"    {\n",
+"        \"response\": \"Bananas are high in potassium and great for quick energy.\",\n",
+"        \"expected_answer\": \"Bananas provide potassium and are a good source of fast-digesting carbohydrates.\"\n",
+"    },\n",
+"    {\"score\": \"pass\"}\n",
")\n",
"\n",
+"prompt.add_example(\n",
+"    {\n",
+"        \"response\": \"Using two-factor authentication greatly enhances account security.\",\n",
+"        \"expected_answer\": \"Two-factor authentication adds a layer of protection by requiring a second form of identity verification.\"\n",
+"    },\n",
+"    {\"score\": \"fail\"}\n",
+")\n",
+"\n",
+"\n",
"prompt.example_store.get_examples(\n",
-"    {\"response\": \"You can get a full refund if you miss your flight.\", \"expected_answer\": \"Refunds depend on ticket type; only refundable tickets qualify for full refunds.\"})"
+"{\n",
+"    \"response\": \"Regularly updating your software reduces the risk of vulnerabilities.\",\n",
+"    \"expected_answer\": \"Keeping software up to date helps patch known security flaws and prevents exploits.\"\n",
+"    })\n",
+"\n",
+"print(prompt.format(**{\n",
+"    \"response\": \"Regularly updating your software reduces the risk of vulnerabilities.\",\n",
+"    \"expected_answer\": \"Keeping software up to date helps patch known security flaws and prevents exploits.\"\n",
+"    }))"
]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": []
}
],
-"metadata": {},
+"metadata": {
+"kernelspec": {
+"display_name": "python3",
+"language": "python",
+"name": "python3"
+}
+},
"nbformat": 4,
"nbformat_minor": 2
}
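The first hunk above makes add_example idempotent for identical pairs. A small sketch of the effect, assuming prompt is the DynamicFewShotPrompt built in the notebook cell above:

example_input = {
    "response": "Bananas are high in potassium and great for quick energy.",
    "expected_answer": "Bananas provide potassium and are a good source of fast-digesting carbohydrates.",
}
example_output = {"score": "pass"}

prompt.add_example(example_input, example_output)
prompt.add_example(example_input, example_output)  # exact duplicate: skipped by the new membership check

# the (inputs, output) pair is stored only once in prompt.examples and in the example store
assert prompt.examples.count((example_input, example_output)) == 1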
