
Commit 3ba3f46

Merge pull request #4 from explodinggradients/prompt
prompt and few shot learning
2 parents: 6142800 + 09eec74

19 files changed, +2297 -56 lines changed


nbs/embedding/base.ipynb

Lines changed: 1150 additions & 0 deletions
Large diffs are not rendered by default.

nbs/metric/base.ipynb

Lines changed: 54 additions & 4 deletions
@@ -42,19 +42,32 @@
"from dataclasses import dataclass, field\n",
"from pydantic import BaseModel\n",
"import typing as t\n",
+ "import json\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "from ragas_annotator.prompt.base import Prompt\n",
+ "from ragas_annotator.embedding.base import BaseEmbedding\n",
"from ragas_annotator.metric import MetricResult\n",
"from ragas_annotator.llm import RagasLLM\n",
+ "from ragas_annotator.project.core import Project\n",
+ "from ragas_annotator.model.notion_model import NotionModel\n",
+ "from ragas_annotator.prompt.dynamic_few_shot import DynamicFewShotPrompt\n",
+ "\n",
"\n",
"@dataclass\n",
"class Metric(ABC):\n",
" \"\"\"Base class for all metrics in the LLM evaluation library.\"\"\"\n",
" name: str\n",
- " prompt: str\n",
+ " prompt: str | Prompt\n",
" llm: RagasLLM\n",
" _response_models: t.Dict[bool, t.Type[BaseModel]] = field(\n",
" default_factory=dict, init=False, repr=False\n",
" )\n",
" \n",
+ " def __post_init__(self):\n",
+ " if isinstance(self.prompt,str):\n",
+ " self.prompt = Prompt(self.prompt)\n",
+ " \n",
" @abstractmethod\n",
" def _get_response_model(self, with_reasoning: bool) -> t.Type[BaseModel]:\n",
" \"\"\"Get the appropriate response model.\"\"\"\n",
@@ -67,22 +80,32 @@
" \n",
" def score(self, reasoning: bool = True, n: int = 1, **kwargs) -> t.Any:\n",
" responses = []\n",
+ " traces = {}\n",
+ " traces[\"input\"] = kwargs\n",
" prompt_input = self.prompt.format(**kwargs)\n",
" for _ in range(n):\n",
" response = self.llm.generate(prompt_input, response_model = self._get_response_model(reasoning)) \n",
+ " traces['output'] = response.model_dump()\n",
" response = MetricResult(**response.model_dump())\n",
" responses.append(response)\n",
- " return self._ensemble(responses)\n",
+ " results = self._ensemble(responses)\n",
+ " results.traces = traces\n",
+ " return results\n",
"\n",
"\n",
" async def ascore(self, reasoning: bool = True, n: int = 1, **kwargs) -> MetricResult:\n",
" responses = [] # Added missing initialization\n",
+ " traces = {}\n",
+ " traces[\"input\"] = kwargs\n",
" prompt_input = self.prompt.format(**kwargs)\n",
" for _ in range(n):\n",
" response = await self.llm.agenerate(prompt_input, response_model = self._get_response_model(reasoning))\n",
+ " traces['output'] = response.model_dump()\n",
" response = MetricResult(**response.model_dump()) # Fixed missing parentheses\n",
" responses.append(response)\n",
- " return self._ensemble(responses)\n",
+ " results = self._ensemble(responses)\n",
+ " results.traces = traces\n",
+ " return results\n",
" \n",
" def batch_score(self, inputs: t.List[t.Dict[str, t.Any]], reasoning: bool = True, n: int = 1) -> t.List[t.Any]:\n",
" return [self.score(reasoning, n, **input_dict) for input_dict in inputs]\n",
@@ -94,7 +117,34 @@
" async_tasks.append(self.ascore(reasoning=reasoning, n=n, **input_dict))\n",
" \n",
" # Run all tasks concurrently and return results\n",
- " return await asyncio.gather(*async_tasks)"
+ " return await asyncio.gather(*async_tasks)\n",
+ " \n",
+ " def train(self,project:Project, experiment_names: t.List[str], model:NotionModel, embedding_model: BaseEmbedding,method: t.Dict[str, t.Any]):\n",
+ " \n",
+ " assert isinstance(self.prompt, Prompt)\n",
+ " self.prompt = DynamicFewShotPrompt.from_prompt(self.prompt,embedding_model)\n",
+ " datasets = []\n",
+ " for experiment_name in experiment_names:\n",
+ " experiment_data = project.get_experiment(experiment_name,model)\n",
+ " experiment_data.load()\n",
+ " datasets.append(experiment_data)\n",
+ " \n",
+ " total_items = sum([len(dataset) for dataset in datasets])\n",
+ " with tqdm(total=total_items, desc=\"Processing examples\") as pbar:\n",
+ " for dataset in datasets:\n",
+ " for row in dataset:\n",
+ " if hasattr(row, f'{self.name}_traces'):\n",
+ " traces = json.loads(getattr(row, f'{self.name}_traces'))\n",
+ " if traces:\n",
+ " self.prompt.add_example(traces['input'],traces['output'])\n",
+ " pbar.update(1)\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " "
]
},
{
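Taken together, the base-class changes mean a metric can now be built from a plain string prompt (coerced to a Prompt in __post_init__), every score()/ascore() call attaches an input/output trace to the returned result, and train() replays annotated experiment traces into a DynamicFewShotPrompt. The following is a minimal usage sketch, not part of the commit: it assumes a concrete subclass such as DiscreteMetric with a values field, and leaves the ragas_llm and train() dependencies as placeholders since their constructors are not shown in this diff.

# Usage sketch only: the `values` field on DiscreteMetric and the ragas_llm
# arguments are assumptions, not part of this diff.
from ragas_annotator.llm import ragas_llm
from ragas_annotator.metric import DiscreteMetric

llm = ragas_llm(...)  # hypothetical setup; the exact arguments are not shown in this commit

metric = DiscreteMetric(
    name="helpfulness",
    # A plain string is accepted and wrapped into a Prompt by __post_init__
    prompt="Evaluate if given answer is helpful\n\n{response}",
    llm=llm,
    values=["low", "med", "high"],  # assumed DiscreteMetric-specific field
)

result = metric.score(response="my response")
print(result)         # the discrete score
print(result.traces)  # {'input': {'response': 'my response'}, 'output': {...}} captured during scoring

# Once experiments have been annotated, traces can be replayed as few-shot examples:
# metric.train(project=project, experiment_names=["exp-1"], model=MyNotionModel,
#              embedding_model=embedding_model, method={})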

nbs/metric/decorator.ipynb

Lines changed: 9 additions & 16 deletions
@@ -31,6 +31,7 @@
"from dataclasses import dataclass\n",
"from ragas_annotator.metric import MetricResult\n",
"from ragas_annotator.llm import RagasLLM\n",
+ "from ragas_annotator.prompt.base import Prompt\n",
"\n",
"\n",
"\n",
@@ -45,7 +46,7 @@
" Returns:\n",
" A decorator factory function for the specified metric type\n",
" \"\"\"\n",
- " def decorator_factory(llm:RagasLLM, prompt, name: t.Optional[str] = None, **metric_params):\n",
+ " def decorator_factory(llm:RagasLLM, prompt: t.Union[str, Prompt], name: t.Optional[str] = None, **metric_params):\n",
" \"\"\"\n",
" Creates a decorator that wraps a function into a metric instance.\n",
" \n",
@@ -64,17 +65,9 @@
" metric_name = name or func.__name__\n",
" is_async = inspect.iscoroutinefunction(func)\n",
" \n",
+ " #TODO: Move to dataclass type implementation\n",
" @dataclass\n",
" class CustomMetric(metric_class):\n",
- " def _extract_result(self, result, reasoning: bool):\n",
- " \"\"\"Extract score and reason from the result.\"\"\"\n",
- " if isinstance(result, tuple) and len(result) == 2:\n",
- " score, reason = result\n",
- " else:\n",
- " score, reason = result, None\n",
- " \n",
- " # Use \"result\" instead of \"score\" for the new MetricResult implementation\n",
- " return MetricResult(result=score, reason=reason if reasoning else None)\n",
" \n",
" def _run_sync_in_async(self, func, *args, **kwargs):\n",
" \"\"\"Run a synchronous function in an async context.\"\"\"\n",
@@ -101,7 +94,7 @@
" # Sync function implementation\n",
" result = func(self.llm, self.prompt, **kwargs)\n",
" \n",
- " return self._extract_result(result, reasoning)\n",
+ " return result\n",
" except Exception as e:\n",
" # Handle errors gracefully\n",
" error_msg = f\"Error executing metric {self.name}: {str(e)}\"\n",
@@ -120,7 +113,7 @@
" else:\n",
" # For sync functions, run normally\n",
" result = self._run_sync_in_async(func, self.llm, self.prompt, **kwargs)\n",
- " return self._extract_result(result, reasoning)\n",
+ " return result\n",
" \n",
" # Create the metric instance with all parameters\n",
" metric_instance = CustomMetric(\n",
@@ -159,16 +152,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "high\n",
- "reason\n"
+ "low\n",
+ "The context or details of the user's response ('my response') are not provided, making it impossible to evaluate its helpfulness accurately.\n"
]
}
],
"source": [
"#| eval: false\n",
"\n",
"\n",
- "from ragas_annotator.metric import DiscreteMetric\n",
+ "from ragas_annotator.metric import DiscreteMetric, MetricResult\n",
"from pydantic import BaseModel\n",
"\n",
"from ragas_annotator.llm import ragas_llm\n",
@@ -193,7 +186,7 @@
" score = 'low'\n",
" else:\n",
" score = 'high'\n",
- " return score,\"reason\"\n",
+ " return MetricResult(result=score, reason=response.reason)\n",
"\n",
"result = my_metric.score(response='my response') # result\n",
"print(result)\n",

nbs/metric/discrete.ipynb

Lines changed: 21 additions & 16 deletions
@@ -21,16 +21,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/opt/homebrew/Caskroom/miniforge/base/envs/random/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"#| export\n",
"import typing as t\n",
@@ -99,8 +90,17 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "low\n",
- "The response does not provide any specific information or context that can help evaluate its helpfulness.\n"
+ "med\n",
+ "The given input \"this is my response\" is too vague to provide a comprehensive evaluation.\n",
+ "\n",
+ "Positives:\n",
+ "1. Clear Statement: It's a straightforward indication that a response has been provided.\n",
+ "\n",
+ "Negatives:\n",
+ "1. Lack of Context: Without context or additional information, it's impossible to assess the relevance or accuracy of the response.\n",
+ "2. No Specificity: The response doesn't convey any specific information or insight related to a topic or question.\n",
+ "\n",
+ "If this response was intended to be part of a conversation or instruction, more detail would be required to make it highly effective. At present, it serves as a neutral statement without actionable or informative content.\n"
]
}
],
@@ -143,13 +143,15 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "high\n",
- "reason\n"
+ "low\n",
+ "The prompt 'my response' does not provide sufficient information or context for me to evaluate its helpfulness. An answer needs to be specific and provide insight or information relative to a clear question or context.\n"
]
}
],
"source": [
"#| eval: false\n",
+ "from ragas_annotator.metric.result import MetricResult\n",
+ "\n",
"@discrete_metric(llm=llm,\n",
" prompt=\"Evaluate if given answer is helpful\\n\\n{response}\",\n",
" name='new_metric',values=[\"low\",\"med\",\"high\"])\n",
@@ -158,14 +160,17 @@
" class response_model(BaseModel):\n",
" output: t.List[bool]\n",
" reason: str\n",
- " \n",
+ " traces = {}\n",
+ " traces['input'] = kwargs\n",
" response = llm.generate(prompt.format(**kwargs),response_model=response_model)\n",
+ " traces['output'] = response.model_dump()\n",
" total = sum(response.output)\n",
" if total < 1:\n",
" score = 'low'\n",
" else:\n",
" score = 'high'\n",
- " return score,\"reason\"\n",
+ " \n",
+ " return MetricResult(result=score,reason=response.reason,traces=traces)\n",
"\n",
"result = my_metric.score(response='my response') # result\n",
"print(result)\n",

nbs/metric/numeric.ipynb

Lines changed: 5 additions & 1 deletion
@@ -147,6 +147,7 @@
"source": [
"\n",
"#| eval: false\n",
+ "from ragas_annotator.metric import MetricResult\n",
"\n",
"@numeric_metric(llm=llm,\n",
" prompt=\"Evaluate if given answer is helpful\\n\\n{response}\",\n",
@@ -157,13 +158,16 @@
" output: int\n",
" reason: str\n",
" \n",
+ " traces = {}\n",
+ " traces['input'] = kwargs\n",
" response = llm.generate(prompt.format(**kwargs),response_model=response_model)\n",
+ " traces['output'] = response.dict()\n",
" total = response.output\n",
" if total < 1:\n",
" score = 0\n",
" else:\n",
" score = 10\n",
- " return score,\"reason\"\n",
+ " return MetricResult(result=score,reason=response.reason,traces=traces)\n",
"\n",
"result = my_metric.score(response='my response') # result\n",
"result # 10\n",

nbs/metric/ranking.ipynb

Lines changed: 2 additions & 1 deletion
@@ -185,6 +185,7 @@
"source": [
"#| eval: false\n",
"\n",
+ "from ragas_annotator.metric import MetricResult\n",
"\n",
"@ranking_metric(\n",
" llm=llm, # Your language model instance\n",
@@ -197,7 +198,7 @@
" # For example, process the prompt (formatted with candidates) and produce a ranking.\n",
" ranking = [1, 0, 2] # Dummy ranking: second candidate is best, then first, then third.\n",
" reason = \"Ranked based on response clarity and detail.\"\n",
- " return ranking, reason\n",
+ " return MetricResult(result=ranking, reason=reason)\n",
"\n",
"# Using the decorator-based ranking metric:\n",
"result = my_ranking_metric.score(candidates=[\n",

nbs/metric/result.ipynb

Lines changed: 6 additions & 1 deletion
@@ -46,9 +46,14 @@
" - RankingMetrics (list results)\n",
" \"\"\"\n",
" \n",
- " def __init__(self, result: t.Any, reason: t.Optional[str] = None):\n",
+ " def __init__(self, result: t.Any, reason: t.Optional[str] = None, traces: t.Optional[t.Dict[str, t.Any]] = None):\n",
+ " if traces is not None:\n",
+ " invalid_keys = [key for key in traces.keys() if key not in {\"input\", \"output\"}]\n",
+ " if invalid_keys:\n",
+ " raise ValueError(f\"Invalid keys in traces: {invalid_keys}. Allowed keys are 'input' and 'output'.\")\n",
" self._result = result\n",
" self.reason = reason\n",
+ " self.traces = traces\n",
" \n",
" def __repr__(self):\n",
" return repr(self._result)\n",
