95 changes: 95 additions & 0 deletions nbs/embedding/base.ipynb
@@ -0,0 +1,95 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#| default_exp embedding.base"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## RagasEmbedding"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"\n",
"import typing as t\n",
"from dataclasses import dataclass\n",
"\n",
"\n",
"@dataclass\n",
"class RagasEmbedding:\n",
" client: t.Any\n",
" model: str\n",
" \n",
" def embed_text(self,text:str,**kwargs: t.Any) -> t.List[float]:\n",
" \n",
" return self.client.embeddings.create(input=text, model=self.model, **kwargs).data[0].embedding\n",
" \n",
" async def aembed_text(self,text:str,**kwargs: t.Any):\n",
" \n",
" await self.client.embeddings.create(input=text, model=self.model, **kwargs).data[0].embedding\n",
" \n",
" \n",
" def embed_document(self,documents:t.List[str],**kwargs: t.Any) -> t.List[t.List[float]]:\n",
" embeddings = self.client.embeddings.create(input=documents, model=self.model, **kwargs)\n",
" return [embedding.embedding for embedding in embeddings.data]\n",
" \n",
" async def aembed_document(self,documents:t.List[str],**kwargs: t.Any) -> t.List[t.List[float]]:\n",
" embeddings = await self.client.embeddings.create(input=documents, model=self.model, **kwargs)\n",
" return [embedding.embedding for embedding in embeddings.data]\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example Usage"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| eval: false\n",
"\n",
"from openai import OpenAI\n",
"embedding = RagasEmbedding(client=OpenAI(),model=\"text-embedding-3-small\")\n",
"embedding.embed_text(\"Hello, world!\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
257 changes: 257 additions & 0 deletions nbs/llm/llm.ipynb
@@ -0,0 +1,257 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| default_exp llm.llm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# LLM Interface for Ragas"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"\n",
"import typing as t\n",
"import asyncio\n",
"import inspect\n",
"import threading\n",
"from pydantic import BaseModel\n",
"import instructor\n",
"\n",
"T = t.TypeVar('T', bound=BaseModel)\n",
"\n",
"class RagasLLM:\n",
" def __init__(self, provider: str, model: str, client: t.Any, **model_args):\n",
" self.provider = provider.lower()\n",
" self.model = model\n",
" self.model_args = model_args or {}\n",
" self.client = self._initialize_client(provider, client)\n",
" # Check if client is async-capable at initialization\n",
" self.is_async = self._check_client_async()\n",
" \n",
" def _check_client_async(self) -> bool:\n",
" \"\"\"Determine if the client is async-capable.\"\"\"\n",
" try:\n",
" # Check if this is an async client by checking for a coroutine method\n",
" if hasattr(self.client.chat.completions, 'create'):\n",
" return inspect.iscoroutinefunction(self.client.chat.completions.create)\n",
" return False\n",
" except (AttributeError, TypeError):\n",
" return False\n",
" \n",
" def _initialize_client(self, provider: str, client: t.Any) -> t.Any:\n",
" provider = provider.lower()\n",
" \n",
" if provider == \"openai\":\n",
" return instructor.from_openai(client)\n",
" elif provider == \"anthropic\":\n",
" return instructor.from_anthropic(client)\n",
" elif provider == \"cohere\":\n",
" return instructor.from_cohere(client)\n",
" elif provider == \"gemini\":\n",
" return instructor.from_gemini(client)\n",
" elif provider == \"litellm\":\n",
" return instructor.from_litellm(client)\n",
" else:\n",
" raise ValueError(f\"Unsupported provider: {provider}\")\n",
" \n",
" def _run_async_in_current_loop(self, coro):\n",
" \"\"\"Run an async coroutine in the current event loop if possible.\n",
" \n",
" This handles Jupyter environments correctly by using a separate thread\n",
" when a running event loop is detected.\n",
" \"\"\"\n",
" try:\n",
" # Try to get the current event loop\n",
" loop = asyncio.get_event_loop()\n",
" \n",
" if loop.is_running():\n",
" # If the loop is already running (like in Jupyter notebooks),\n",
" # we run the coroutine in a separate thread with its own event loop\n",
" result_container = {'result': None, 'exception': None}\n",
" \n",
" def run_in_thread():\n",
" # Create a new event loop for this thread\n",
" new_loop = asyncio.new_event_loop()\n",
" asyncio.set_event_loop(new_loop)\n",
" try:\n",
" # Run the coroutine in this thread's event loop\n",
" result_container['result'] = new_loop.run_until_complete(coro)\n",
" except Exception as e:\n",
" # Capture any exceptions to re-raise in the main thread\n",
" result_container['exception'] = e\n",
" finally:\n",
" # Clean up the event loop\n",
" new_loop.close()\n",
" \n",
" # Start the thread and wait for it to complete\n",
" thread = threading.Thread(target=run_in_thread)\n",
" thread.start()\n",
" thread.join()\n",
" \n",
" # Re-raise any exceptions that occurred in the thread\n",
" if result_container['exception']:\n",
" raise result_container['exception']\n",
" \n",
" return result_container['result']\n",
" else:\n",
" # Standard case - event loop exists but isn't running\n",
" return loop.run_until_complete(coro)\n",
" \n",
" except RuntimeError:\n",
" # If we get a runtime error about no event loop, create a new one\n",
" loop = asyncio.new_event_loop()\n",
" asyncio.set_event_loop(loop)\n",
" try:\n",
" return loop.run_until_complete(coro)\n",
" finally:\n",
" # Clean up\n",
" loop.close()\n",
" asyncio.set_event_loop(None)\n",
" \n",
" def generate(self, prompt: str, response_model: t.Type[T]) -> T:\n",
" \"\"\"Generate a response using the configured LLM.\n",
" \n",
" For async clients, this will run the async method in the appropriate event loop.\n",
" \"\"\"\n",
" messages = [{\"role\": \"user\", \"content\": prompt}]\n",
" \n",
" # If client is async, use the appropriate method to run it\n",
" if self.is_async:\n",
" return self._run_async_in_current_loop(\n",
" self.agenerate(prompt, response_model)\n",
" )\n",
" else:\n",
" # Regular sync client, just call the method directly\n",
" return self.client.chat.completions.create(\n",
" model=self.model,\n",
" messages=messages,\n",
" response_model=response_model,\n",
" **self.model_args,\n",
" )\n",
" \n",
" async def agenerate(self, prompt: str, response_model: t.Type[T]) -> T:\n",
" \"\"\"Asynchronously generate a response using the configured LLM.\"\"\"\n",
" messages = [{\"role\": \"user\", \"content\": prompt}]\n",
" \n",
" # If client is not async, raise a helpful error\n",
" if not self.is_async:\n",
" raise TypeError(\n",
" \"Cannot use agenerate() with a synchronous client. Use generate() instead.\"\n",
" )\n",
" \n",
" # Regular async client, call the method directly\n",
" return await self.client.chat.completions.create(\n",
" model=self.model,\n",
" messages=messages,\n",
" response_model=response_model,\n",
" **self.model_args,\n",
" )\n",
"\n",
"def ragas_llm(provider: str, model: str, client: t.Any, **model_args) -> RagasLLM:\n",
" return RagasLLM(provider=provider, client=client, model=model, **model_args)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example Usage"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| eval: false\n",
"\n",
"from openai import OpenAI\n",
"class Response(BaseModel):\n",
" response: str\n",
"\n",
"llm = ragas_llm(provider=\"openai\",model=\"gpt-4o\",client=OpenAI())\n",
"llm.generate(\"What is the capital of India?\",response_model=Response) #works fine\n",
"\n",
"try:\n",
" await llm.agenerate(\"What is the capital of India?\", response_model=Response)\n",
"except TypeError as e:\n",
" assert isinstance(e, TypeError)\n",
"#gives TypeError: object Response can't be used in 'await' expression\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Response(response='The capital of India is New Delhi.')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#| eval: false\n",
"\n",
"from openai import AsyncOpenAI\n",
"\n",
"llm = ragas_llm(provider=\"openai\",model=\"gpt-4o\",client=AsyncOpenAI())\n",
"await llm.agenerate(\"What is the capital of India?\",response_model=Response)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Response(response='The capital of India is New Delhi.')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#| eval: false\n",
"\n",
"from anthropic import Anthropic\n",
"\n",
"llm = ragas_llm(provider=\"anthropic\",model=\"claude-3-opus-20240229\",client=Anthropic(),max_tokens=1024)\n",
"llm.generate(\"What is the capital of India?\",response_model=Response)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}