
Commit c14f806

Merge pull request #1 from explodinggradients/llm-as-judge
feat: metrics
2 parents 4d57b18 + 334ec8d commit c14f806

File tree

18 files changed (+1969, -3 lines)

nbs/metric/base.ipynb

Lines changed: 142 additions & 0 deletions (new file)
```python
#| default_exp metric.base
```

# BaseMetric

> Base class for all types of metrics in ragas.
```python
#| export

import asyncio
import typing as t
from abc import ABC, abstractmethod
from dataclasses import dataclass, field

from pydantic import BaseModel

from ragas_annotator.metric import LLM, MetricResult


@dataclass
class Metric(ABC):
    """Base class for all metrics in the LLM evaluation library."""
    name: str
    prompt: str
    llm: LLM
    _response_models: t.Dict[bool, t.Type[BaseModel]] = field(
        default_factory=dict, init=False, repr=False
    )

    @abstractmethod
    def _get_response_model(self, with_reasoning: bool) -> t.Type[BaseModel]:
        """Get the appropriate response model."""
        ...

    @abstractmethod
    def _ensemble(self, results: t.List[MetricResult]) -> MetricResult:
        """Combine the n sampled results into a single result."""
        ...

    def score(self, reasoning: bool = True, n: int = 1, **kwargs) -> t.Any:
        responses = []
        prompt_input = self.prompt.format(**kwargs)
        for _ in range(n):
            response = self.llm.generate(prompt_input, response_model=self._get_response_model(reasoning))
            responses.append(MetricResult(**response.model_dump()))
        return self._ensemble(responses)

    async def ascore(self, reasoning: bool = True, n: int = 1, **kwargs) -> MetricResult:
        responses = []
        prompt_input = self.prompt.format(**kwargs)
        for _ in range(n):
            response = await self.llm.agenerate(prompt_input, response_model=self._get_response_model(reasoning))
            responses.append(MetricResult(**response.model_dump()))
        return self._ensemble(responses)

    def batch_score(self, inputs: t.List[t.Dict[str, t.Any]], reasoning: bool = True, n: int = 1) -> t.List[t.Any]:
        return [self.score(reasoning=reasoning, n=n, **input_dict) for input_dict in inputs]

    async def abatch_score(self, inputs: t.List[t.Dict[str, t.Any]], reasoning: bool = True, n: int = 1) -> t.List[MetricResult]:
        # Build one scoring coroutine per input, passing reasoning and n through.
        async_tasks = [self.ascore(reasoning=reasoning, n=n, **input_dict) for input_dict in inputs]
        # Run all tasks concurrently and return results in input order.
        return await asyncio.gather(*async_tasks)
```
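The `n` parameter lets a metric sample the LLM several times and reconcile the samples in `_ensemble`. As an editorial illustration only (not part of this commit; it reuses `t` and `MetricResult` from the cell above and assumes `MetricResult` exposes the value as `.result`, matching the `MetricResult(result=..., reason=...)` constructor used elsewhere in this PR), a majority-vote ensemble for a discrete metric could look like this:

```python
from collections import Counter

def majority_vote(results: t.List[MetricResult]) -> MetricResult:
    # Count how often each discrete value was returned across the n samples.
    counts = Counter(r.result for r in results)
    winner, _ = counts.most_common(1)[0]
    # Keep the reason from the first sample that agreed with the winning value.
    reason = next((r.reason for r in results if r.result == winner), None)
    return MetricResult(result=winner, reason=reason)
```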
### Example

```python
#| eval: false

@dataclass
class CustomMetric(Metric):
    values: t.List[str] = field(default_factory=lambda: ["pass", "fail"])

    def _get_response_model(self, with_reasoning: bool) -> t.Type[BaseModel]:
        """Get or create a response model based on the reasoning parameter."""

        class mymodel(BaseModel):
            result: int
            reason: t.Optional[str] = None

        return mymodel

    def _ensemble(self, results: t.List[MetricResult]) -> MetricResult:
        return results[0]  # Placeholder for ensemble logic

my_metric = CustomMetric(name="example", prompt="What is the result of {input}?", llm=LLM())
my_metric.score(input="test")
```
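Since `ascore` and `abatch_score` are coroutines, a script has to drive them with an event loop. A usage sketch, not part of the commit, reusing the `CustomMetric` from the example above and assuming a configured `LLM` client:

```python
import asyncio

# Hypothetical inputs; each dict is unpacked into the prompt template.
inputs = [{"input": "2 + 2"}, {"input": "3 + 5"}]

metric = CustomMetric(name="example", prompt="What is the result of {input}?", llm=LLM())

async def main():
    # abatch_score schedules one ascore() coroutine per input and gathers them,
    # so results come back in the same order as the inputs.
    results = await metric.abatch_score(inputs, reasoning=True, n=1)
    for r in results:
        print(r, r.reason)

asyncio.run(main())
```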

nbs/metric/decorator.ipynb

Lines changed: 215 additions & 0 deletions (new file)
```python
#| default_exp metric.decorator
```

# Decorator factory for metrics

> Decorator factory for creating custom metrics.
```python
#| export

import typing as t
import inspect
import asyncio
from dataclasses import dataclass
from ragas_annotator.metric import MetricResult


def create_metric_decorator(metric_class):
    """
    Factory function that creates decorator factories for different metric types.

    Args:
        metric_class: The metric class to use (DiscreteMetrics, NumericMetrics, etc.)

    Returns:
        A decorator factory function for the specified metric type
    """
    def decorator_factory(llm, prompt, name: t.Optional[str] = None, **metric_params):
        """
        Creates a decorator that wraps a function into a metric instance.

        Args:
            llm: The language model instance to use
            prompt: The prompt template
            name: Optional name for the metric (defaults to the function name)
            **metric_params: Additional parameters specific to the metric type
                             (values for DiscreteMetrics, range for NumericMetrics, etc.)

        Returns:
            A decorator function
        """
        def decorator(func):
            # Get the metric name and check whether the function is async.
            metric_name = name or func.__name__
            is_async = inspect.iscoroutinefunction(func)

            @dataclass
            class CustomMetric(metric_class):
                def _extract_result(self, result, reasoning: bool):
                    """Extract score and reason from the wrapped function's return value."""
                    if isinstance(result, tuple) and len(result) == 2:
                        score, reason = result
                    else:
                        score, reason = result, None
                    return MetricResult(result=score, reason=reason if reasoning else None)

                def _run_sync_in_async(self, func, *args, **kwargs):
                    """Run a synchronous function from an async context."""
                    return func(*args, **kwargs)

                def _execute_metric(self, is_async_execution, reasoning, **kwargs):
                    """Execute the wrapped function with proper sync/async handling."""
                    try:
                        if is_async:
                            # The wrapped function is a coroutine function. score() only
                            # calls this from a synchronous context, so drive it with an
                            # event loop here; asynchronous callers go through ascore(),
                            # which awaits the function directly.
                            try:
                                loop = asyncio.get_event_loop()
                            except RuntimeError:
                                loop = asyncio.new_event_loop()
                                asyncio.set_event_loop(loop)
                            result = loop.run_until_complete(func(self.llm, self.prompt, **kwargs))
                        else:
                            # Plain synchronous function: call it directly.
                            result = func(self.llm, self.prompt, **kwargs)

                        return self._extract_result(result, reasoning)
                    except Exception as e:
                        # Surface failures as a MetricResult instead of raising.
                        error_msg = f"Error executing metric {self.name}: {str(e)}"
                        return MetricResult(result=None, reason=error_msg)

                def score(self, reasoning: bool = True, n: int = 1, **kwargs):
                    """Synchronous scoring method."""
                    return self._execute_metric(is_async_execution=False, reasoning=reasoning, **kwargs)

                async def ascore(self, reasoning: bool = True, n: int = 1, **kwargs):
                    """Asynchronous scoring method."""
                    if is_async:
                        # For async functions, await the result.
                        result = await func(self.llm, self.prompt, **kwargs)
                        return self._extract_result(result, reasoning)
                    else:
                        # For sync functions, run them normally.
                        result = self._run_sync_in_async(func, self.llm, self.prompt, **kwargs)
                        return self._extract_result(result, reasoning)

            # Create the metric instance with all parameters.
            metric_instance = CustomMetric(
                name=metric_name,
                prompt=prompt,
                llm=llm,
                **metric_params
            )

            # Preserve the wrapped function's metadata.
            metric_instance.__name__ = metric_name
            metric_instance.__doc__ = func.__doc__

            return metric_instance

        return decorator

    return decorator_factory
```
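`_extract_result` defines the contract for decorated functions: they may return either a bare score or a `(score, reason)` tuple, and a bare score yields `reason=None`. A sketch of both styles (editorial, not from the commit; assumes `DiscreteMetric` and `LLM` from this PR, with the LLM call omitted for brevity):

```python
from ragas_annotator.metric import DiscreteMetric
from ragas_annotator.metric.llm import LLM

discrete_metric = create_metric_decorator(DiscreteMetric)

@discrete_metric(llm=LLM(), prompt="Is the answer helpful?\n\n{response}",
                 values=["low", "med", "high"])
def score_only(llm, prompt, **kwargs):
    # Returning a bare score: _extract_result sets reason to None.
    return "med"

@discrete_metric(llm=LLM(), prompt="Is the answer helpful?\n\n{response}",
                 values=["low", "med", "high"])
def score_with_reason(llm, prompt, **kwargs):
    # Returning a (score, reason) tuple: both fields are populated.
    return "high", "the response addresses the question directly"

print(score_only.score(response="some response"))         # reason is None
print(score_with_reason.score(response="some response"))  # reason is set
```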
### Example usage

```python
#| eval: false

from ragas_annotator.metric import DiscreteMetric
from ragas_annotator.metric.llm import LLM
from pydantic import BaseModel

discrete_metric = create_metric_decorator(DiscreteMetric)

@discrete_metric(llm=LLM(),
                 prompt="Evaluate if given answer is helpful\n\n{response}",
                 name='new_metric', values=["low", "med", "high"])
def my_metric(llm, prompt, **kwargs):

    class response_model(BaseModel):
        output: t.List[bool]
        reason: str

    response = llm.generate(prompt.format(**kwargs), response_model=response_model)
    total = sum(response.output)
    if total < 1:
        score = 'low'
    else:
        score = 'high'
    return score, "reason"

result = my_metric.score(response='my response')
print(result)
print(result.reason)
```

Output:

```
high
reason
```
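Because the decorator checks `inspect.iscoroutinefunction`, it also wraps async functions, which `ascore` awaits directly. An editorial sketch (not part of the commit) that reuses `discrete_metric`, `LLM`, and `t` from the cells above and assumes the `LLM` exposes the `agenerate` method used by the base class:

```python
import asyncio
from pydantic import BaseModel

@discrete_metric(llm=LLM(),
                 prompt="Evaluate if given answer is helpful\n\n{response}",
                 name='async_metric', values=["low", "med", "high"])
async def my_async_metric(llm, prompt, **kwargs):

    class response_model(BaseModel):
        output: t.List[bool]
        reason: str

    # Await the async client; the decorator detects the coroutine function
    # and awaits it inside ascore() (or drives it with an event loop in score()).
    response = await llm.agenerate(prompt.format(**kwargs), response_model=response_model)
    score = 'high' if sum(response.output) >= 1 else 'low'
    return score, response.reason

result = asyncio.run(my_async_metric.ascore(response='my response'))
print(result, result.reason)
```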
