
Commit 6142800

Merge pull request #3 from explodinggradients/llm
feat: ragas llm
2 parents cc567ed + cfe0c73

16 files changed, +509 -138 lines


nbs/llm/llm.ipynb

Lines changed: 257 additions & 0 deletions
Cell 1 (code):

#| default_exp llm.llm

Cell 2 (markdown):

# LLM Interface for Ragas

Cell 3 (code):

#| export

import typing as t
import asyncio
import inspect
import threading
from pydantic import BaseModel
import instructor

T = t.TypeVar('T', bound=BaseModel)

class RagasLLM:
    def __init__(self, provider: str, model: str, client: t.Any, **model_args):
        self.provider = provider.lower()
        self.model = model
        self.model_args = model_args or {}
        self.client = self._initialize_client(provider, client)
        # Check if client is async-capable at initialization
        self.is_async = self._check_client_async()

    def _check_client_async(self) -> bool:
        """Determine if the client is async-capable."""
        try:
            # Check if this is an async client by checking for a coroutine method
            if hasattr(self.client.chat.completions, 'create'):
                return inspect.iscoroutinefunction(self.client.chat.completions.create)
            return False
        except (AttributeError, TypeError):
            return False

    def _initialize_client(self, provider: str, client: t.Any) -> t.Any:
        provider = provider.lower()

        if provider == "openai":
            return instructor.from_openai(client)
        elif provider == "anthropic":
            return instructor.from_anthropic(client)
        elif provider == "cohere":
            return instructor.from_cohere(client)
        elif provider == "gemini":
            return instructor.from_gemini(client)
        elif provider == "litellm":
            return instructor.from_litellm(client)
        else:
            raise ValueError(f"Unsupported provider: {provider}")

    def _run_async_in_current_loop(self, coro):
        """Run an async coroutine in the current event loop if possible.

        This handles Jupyter environments correctly by using a separate thread
        when a running event loop is detected.
        """
        try:
            # Try to get the current event loop
            loop = asyncio.get_event_loop()

            if loop.is_running():
                # If the loop is already running (like in Jupyter notebooks),
                # we run the coroutine in a separate thread with its own event loop
                result_container = {'result': None, 'exception': None}

                def run_in_thread():
                    # Create a new event loop for this thread
                    new_loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(new_loop)
                    try:
                        # Run the coroutine in this thread's event loop
                        result_container['result'] = new_loop.run_until_complete(coro)
                    except Exception as e:
                        # Capture any exceptions to re-raise in the main thread
                        result_container['exception'] = e
                    finally:
                        # Clean up the event loop
                        new_loop.close()

                # Start the thread and wait for it to complete
                thread = threading.Thread(target=run_in_thread)
                thread.start()
                thread.join()

                # Re-raise any exceptions that occurred in the thread
                if result_container['exception']:
                    raise result_container['exception']

                return result_container['result']
            else:
                # Standard case - event loop exists but isn't running
                return loop.run_until_complete(coro)

        except RuntimeError:
            # If we get a runtime error about no event loop, create a new one
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                return loop.run_until_complete(coro)
            finally:
                # Clean up
                loop.close()
                asyncio.set_event_loop(None)

    def generate(self, prompt: str, response_model: t.Type[T]) -> T:
        """Generate a response using the configured LLM.

        For async clients, this will run the async method in the appropriate event loop.
        """
        messages = [{"role": "user", "content": prompt}]

        # If client is async, use the appropriate method to run it
        if self.is_async:
            return self._run_async_in_current_loop(
                self.agenerate(prompt, response_model)
            )
        else:
            # Regular sync client, just call the method directly
            return self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                response_model=response_model,
                **self.model_args,
            )

    async def agenerate(self, prompt: str, response_model: t.Type[T]) -> T:
        """Asynchronously generate a response using the configured LLM."""
        messages = [{"role": "user", "content": prompt}]

        # If client is not async, raise a helpful error
        if not self.is_async:
            raise TypeError(
                "Cannot use agenerate() with a synchronous client. Use generate() instead."
            )

        # Regular async client, call the method directly
        return await self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            response_model=response_model,
            **self.model_args,
        )

def ragas_llm(provider: str, model: str, client: t.Any, **model_args) -> RagasLLM:
    return RagasLLM(provider=provider, client=client, model=model, **model_args)

Cell 4 (markdown):

### Example Usage

Cell 5 (code):

#| eval: false

from openai import OpenAI
class Response(BaseModel):
    response: str

llm = ragas_llm(provider="openai",model="gpt-4o",client=OpenAI())
llm.generate("What is the capital of India?",response_model=Response) #works fine

try:
    await llm.agenerate("What is the capital of India?", response_model=Response)
except TypeError as e:
    assert isinstance(e, TypeError)
#gives TypeError: object Response can't be used in 'await' expression

Cell 6 (code):

#| eval: false

from openai import AsyncOpenAI

llm = ragas_llm(provider="openai",model="gpt-4o",client=AsyncOpenAI())
await llm.agenerate("What is the capital of India?",response_model=Response)

Output:
Response(response='The capital of India is New Delhi.')

Cell 7 (code):

#| eval: false

from anthropic import Anthropic

llm = ragas_llm(provider="anthropic",model="claude-3-opus-20240229",client=Anthropic(),max_tokens=1024)
llm.generate("What is the capital of India?",response_model=Response)

Output:
Response(response='The capital of India is New Delhi.')

Notebook metadata: kernelspec python3, nbformat 4, nbformat_minor 2.

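The helper _run_async_in_current_loop is what lets generate() drive an async client even when an event loop is already running (as it is inside Jupyter): it spins up a fresh event loop on a worker thread, runs the coroutine there, and re-raises any captured exception in the caller. The following standalone sketch illustrates the same pattern outside RagasLLM; the names run_async_in_thread and sample are illustrative, not part of this commit.

import asyncio
import threading

def run_async_in_thread(coro):
    # Run `coro` on a new event loop owned by a worker thread, mirroring
    # RagasLLM._run_async_in_current_loop for already-running-loop environments.
    result = {"value": None, "error": None}

    def worker():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            result["value"] = loop.run_until_complete(coro)
        except Exception as exc:
            # Capture so the exception can be re-raised in the calling thread
            result["error"] = exc
        finally:
            loop.close()

    thread = threading.Thread(target=worker)
    thread.start()
    thread.join()
    if result["error"] is not None:
        raise result["error"]
    return result["value"]

async def sample():
    await asyncio.sleep(0.01)
    return "done"

print(run_async_in_thread(sample()))  # prints "done", even when the caller's own loop is busy
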
nbs/metric/base.ipynb

Lines changed: 15 additions & 4 deletions
@@ -43,14 +43,14 @@
     "from pydantic import BaseModel\n",
     "import typing as t\n",
     "from ragas_annotator.metric import MetricResult\n",
-    "from ragas_annotator.metric import LLM\n",
+    "from ragas_annotator.llm import RagasLLM\n",
     "\n",
     "@dataclass\n",
     "class Metric(ABC):\n",
     "    \"\"\"Base class for all metrics in the LLM evaluation library.\"\"\"\n",
     "    name: str\n",
     "    prompt: str\n",
-    "    llm: LLM\n",
+    "    llm: RagasLLM\n",
     "    _response_models: t.Dict[bool, t.Type[BaseModel]] = field(\n",
     "        default_factory=dict, init=False, repr=False\n",
     "    )\n",
@@ -114,6 +114,11 @@
    "source": [
     "#| eval: false\n",
     "\n",
+    "from ragas_annotator.llm import ragas_llm\n",
+    "from openai import OpenAI\n",
+    "\n",
+    "llm = ragas_llm(provider=\"openai\",model=\"gpt-4o\",client=OpenAI())\n",
+    "\n",
     "@dataclass\n",
     "class CustomMetric(Metric):\n",
     "    values: t.List[str] = field(default_factory=lambda: [\"pass\", \"fail\"])\n",
@@ -131,12 +136,18 @@
     "    \n",
     "    return results[0] # Placeholder for ensemble logic\n",
     "\n",
-    "my_metric = CustomMetric(name=\"example\", prompt=\"What is the result of {input}?\", llm=LLM())\n",
+    "my_metric = CustomMetric(name=\"example\", prompt=\"What is the result of {input}?\", llm=llm)\n",
     "my_metric.score(input=\"test\")"
    ]
   }
  ],
- "metadata": {},
+ "metadata": {
+  "kernelspec": {
+   "display_name": "python3",
+   "language": "python",
+   "name": "python3"
+  }
+ },
  "nbformat": 4,
  "nbformat_minor": 2
 }

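With the base class now typed against RagasLLM, a metric's scoring step reduces to one structured-output call against the interface defined in nbs/llm/llm.ipynb above. A minimal sketch, assuming the metric simply formats its prompt template with the scoring kwargs; FeedbackModel and the example prompt are illustrative, not part of this commit.

from pydantic import BaseModel
from openai import OpenAI
from ragas_annotator.llm import ragas_llm

# Illustrative structured output; a real Metric keeps its models in _response_models.
class FeedbackModel(BaseModel):
    result: str
    reason: str

llm = ragas_llm(provider="openai", model="gpt-4o", client=OpenAI())

# What a Metric subclass effectively does when scoring a single input:
prompt = "What is the result of {input}?".format(input="test")
feedback = llm.generate(prompt, response_model=FeedbackModel)
print(feedback.result, feedback.reason)
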
nbs/metric/decorator.ipynb

Lines changed: 8 additions & 3 deletions
@@ -30,6 +30,7 @@
     "import asyncio\n",
     "from dataclasses import dataclass\n",
     "from ragas_annotator.metric import MetricResult\n",
+    "from ragas_annotator.llm import RagasLLM\n",
     "\n",
     "\n",
     "\n",
@@ -44,7 +45,7 @@
     "    Returns:\n",
     "        A decorator factory function for the specified metric type\n",
     "    \"\"\"\n",
-    "    def decorator_factory(llm, prompt, name: t.Optional[str] = None, **metric_params):\n",
+    "    def decorator_factory(llm:RagasLLM, prompt, name: t.Optional[str] = None, **metric_params):\n",
     "        \"\"\"\n",
     "        Creates a decorator that wraps a function into a metric instance.\n",
     "        \n",
@@ -168,12 +169,16 @@
     "\n",
     "\n",
     "from ragas_annotator.metric import DiscreteMetric\n",
-    "from ragas_annotator.metric.llm import LLM\n",
     "from pydantic import BaseModel\n",
     "\n",
+    "from ragas_annotator.llm import ragas_llm\n",
+    "from openai import OpenAI\n",
+    "\n",
+    "llm = ragas_llm(provider=\"openai\",model=\"gpt-4o\",client=OpenAI())\n",
+    "\n",
     "discrete_metric = create_metric_decorator(DiscreteMetric)\n",
     "\n",
-    "@discrete_metric(llm=LLM(),\n",
+    "@discrete_metric(llm=llm,\n",
     "    prompt=\"Evaluate if given answer is helpful\\n\\n{response}\",\n",
     "    name='new_metric',values=[\"low\",\"med\",\"high\"])\n",
     "def my_metric(llm,prompt,**kwargs):\n",

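The decorator's docstring says it "wraps a function into a metric instance", and metrics are scored through .score(...) as in nbs/metric/base.ipynb above, so the decorated example would presumably be exercised along these lines; the response value is illustrative only.

# Hypothetical invocation of the my_metric defined above; the keyword argument
# fills the {response} placeholder in the metric's prompt template.
result = my_metric.score(response="The answer walks through each step clearly.")
print(result)
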