
Commit 62f9d43

Merge branch 'main' into feat/simpler_annotation

2 parents 3e09967 + b9fa013

26 files changed: +3101 -525 lines

nbs/embedding/base.ipynb

Lines changed: 1150 additions & 0 deletions
Large diffs are not rendered by default.

nbs/llm/llm.ipynb

Lines changed: 257 additions & 0 deletions
```python
#| default_exp llm.llm
```

# LLM Interface for Ragas
```python
#| export

import typing as t
import asyncio
import inspect
import threading
from pydantic import BaseModel
import instructor

T = t.TypeVar('T', bound=BaseModel)

class RagasLLM:
    def __init__(self, provider: str, model: str, client: t.Any, **model_args):
        self.provider = provider.lower()
        self.model = model
        self.model_args = model_args or {}
        self.client = self._initialize_client(provider, client)
        # Check if client is async-capable at initialization
        self.is_async = self._check_client_async()

    def _check_client_async(self) -> bool:
        """Determine if the client is async-capable."""
        try:
            # Check if this is an async client by checking for a coroutine method
            if hasattr(self.client.chat.completions, 'create'):
                return inspect.iscoroutinefunction(self.client.chat.completions.create)
            return False
        except (AttributeError, TypeError):
            return False

    def _initialize_client(self, provider: str, client: t.Any) -> t.Any:
        provider = provider.lower()

        if provider == "openai":
            return instructor.from_openai(client)
        elif provider == "anthropic":
            return instructor.from_anthropic(client)
        elif provider == "cohere":
            return instructor.from_cohere(client)
        elif provider == "gemini":
            return instructor.from_gemini(client)
        elif provider == "litellm":
            return instructor.from_litellm(client)
        else:
            raise ValueError(f"Unsupported provider: {provider}")

    def _run_async_in_current_loop(self, coro):
        """Run an async coroutine in the current event loop if possible.

        This handles Jupyter environments correctly by using a separate thread
        when a running event loop is detected.
        """
        try:
            # Try to get the current event loop
            loop = asyncio.get_event_loop()

            if loop.is_running():
                # If the loop is already running (like in Jupyter notebooks),
                # we run the coroutine in a separate thread with its own event loop
                result_container = {'result': None, 'exception': None}

                def run_in_thread():
                    # Create a new event loop for this thread
                    new_loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(new_loop)
                    try:
                        # Run the coroutine in this thread's event loop
                        result_container['result'] = new_loop.run_until_complete(coro)
                    except Exception as e:
                        # Capture any exceptions to re-raise in the main thread
                        result_container['exception'] = e
                    finally:
                        # Clean up the event loop
                        new_loop.close()

                # Start the thread and wait for it to complete
                thread = threading.Thread(target=run_in_thread)
                thread.start()
                thread.join()

                # Re-raise any exceptions that occurred in the thread
                if result_container['exception']:
                    raise result_container['exception']

                return result_container['result']
            else:
                # Standard case - event loop exists but isn't running
                return loop.run_until_complete(coro)

        except RuntimeError:
            # If we get a runtime error about no event loop, create a new one
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                return loop.run_until_complete(coro)
            finally:
                # Clean up
                loop.close()
                asyncio.set_event_loop(None)

    def generate(self, prompt: str, response_model: t.Type[T]) -> T:
        """Generate a response using the configured LLM.

        For async clients, this will run the async method in the appropriate event loop.
        """
        messages = [{"role": "user", "content": prompt}]

        # If client is async, use the appropriate method to run it
        if self.is_async:
            return self._run_async_in_current_loop(
                self.agenerate(prompt, response_model)
            )
        else:
            # Regular sync client, just call the method directly
            return self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                response_model=response_model,
                **self.model_args,
            )

    async def agenerate(self, prompt: str, response_model: t.Type[T]) -> T:
        """Asynchronously generate a response using the configured LLM."""
        messages = [{"role": "user", "content": prompt}]

        # If client is not async, raise a helpful error
        if not self.is_async:
            raise TypeError(
                "Cannot use agenerate() with a synchronous client. Use generate() instead."
            )

        # Regular async client, call the method directly
        return await self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            response_model=response_model,
            **self.model_args,
        )

def ragas_llm(provider: str, model: str, client: t.Any, **model_args) -> RagasLLM:
    return RagasLLM(provider=provider, client=client, model=model, **model_args)
```
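The trickiest part of the export is `_run_async_in_current_loop`: inside Jupyter an event loop is already running, so calling `run_until_complete` on it would fail, and the coroutine has to be handed to a worker thread with its own loop. Below is a condensed, stdlib-only sketch of that same pattern; `run_sync` and `slow_add` are illustrative names invented here, not part of the exported code:

```python
import asyncio
import threading

async def slow_add(a: int, b: int) -> int:
    # Stand-in for an async LLM call
    await asyncio.sleep(0.1)
    return a + b

def run_sync(coro):
    """Run a coroutine to completion, even if a loop is already running."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop in this thread: asyncio.run() creates and cleans up one
        return asyncio.run(coro)
    # A loop is already running (e.g. Jupyter): hand the coroutine to a
    # worker thread that gets its own fresh event loop
    box = {}
    def worker():
        try:
            box['result'] = asyncio.run(coro)
        except Exception as e:
            # Capture exceptions so they can be re-raised in the caller's thread
            box['error'] = e
    t = threading.Thread(target=worker)
    t.start()
    t.join()
    if 'error' in box:
        raise box['error']
    return box['result']

print(run_sync(slow_add(2, 3)))  # 5, whether or not a loop was already running
```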
### Example Usage
```python
#| eval: false

from openai import OpenAI

class Response(BaseModel):
    response: str

llm = ragas_llm(provider="openai", model="gpt-4o", client=OpenAI())
llm.generate("What is the capital of India?", response_model=Response)  # works fine with a sync client

try:
    await llm.agenerate("What is the capital of India?", response_model=Response)
except TypeError as e:
    assert isinstance(e, TypeError)
# raises TypeError: Cannot use agenerate() with a synchronous client. Use generate() instead.
```
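The sync/async dispatch above hinges on `_check_client_async`, which only inspects whether `client.chat.completions.create` is a coroutine function. That check can be exercised without an API key; the `FakeSyncClient`/`FakeAsyncClient` stand-ins below are invented for illustration and skip the instructor patching:

```python
import inspect
from types import SimpleNamespace

class FakeSyncClient:
    def __init__(self):
        # Plain function -> looks like a synchronous client
        self.chat = SimpleNamespace(
            completions=SimpleNamespace(create=lambda **kw: "sync result")
        )

class FakeAsyncClient:
    def __init__(self):
        # Coroutine function -> looks like an async client
        async def create(**kw):
            return "async result"
        self.chat = SimpleNamespace(completions=SimpleNamespace(create=create))

# Mirrors the check inside RagasLLM._check_client_async
for client in (FakeSyncClient(), FakeAsyncClient()):
    is_async = inspect.iscoroutinefunction(client.chat.completions.create)
    print(type(client).__name__, "-> is_async =", is_async)
# FakeSyncClient -> is_async = False
# FakeAsyncClient -> is_async = True
```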
```python
#| eval: false

from openai import AsyncOpenAI

llm = ragas_llm(provider="openai", model="gpt-4o", client=AsyncOpenAI())
await llm.agenerate("What is the capital of India?", response_model=Response)
```

    Response(response='The capital of India is New Delhi.')
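Since `agenerate` is a plain coroutine, an async client also composes with `asyncio.gather` for concurrent requests. A sketch reusing the `llm` and `Response` from the cell above (not part of this commit; the question list is illustrative):

```python
#| eval: false

import asyncio

questions = [
    "What is the capital of India?",
    "What is the capital of France?",
]

# Fire both requests concurrently on the AsyncOpenAI-backed client
results = await asyncio.gather(
    *(llm.agenerate(q, response_model=Response) for q in questions)
)
for r in results:
    print(r.response)
```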
```python
#| eval: false

from anthropic import Anthropic

llm = ragas_llm(provider="anthropic", model="claude-3-opus-20240229", client=Anthropic(), max_tokens=1024)
llm.generate("What is the capital of India?", response_model=Response)
```

    Response(response='The capital of India is New Delhi.')
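Nothing restricts `response_model` to a single string field; instructor derives the schema from whatever pydantic model is passed in. An illustrative sketch (the `CityInfo` model is invented here, not from the commit):

```python
#| eval: false

from openai import OpenAI

class CityInfo(BaseModel):
    city: str
    country: str
    population_millions: float

llm = ragas_llm(provider="openai", model="gpt-4o", client=OpenAI())
info = llm.generate("Tell me about the capital of India.", response_model=CityInfo)
print(info.city, info.country, info.population_millions)
```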
