Merged
257 changes: 257 additions & 0 deletions nbs/llm/llm.ipynb
@@ -0,0 +1,257 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| default_exp llm.llm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# LLM Interface for Ragas"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"\n",
"import typing as t\n",
"import asyncio\n",
"import inspect\n",
"import threading\n",
"from pydantic import BaseModel\n",
"import instructor\n",
"\n",
"T = t.TypeVar('T', bound=BaseModel)\n",
"\n",
"class RagasLLM:\n",
" def __init__(self, provider: str, model: str, client: t.Any, **model_args):\n",
" self.provider = provider.lower()\n",
" self.model = model\n",
" self.model_args = model_args or {}\n",
" self.client = self._initialize_client(provider, client)\n",
" # Check if client is async-capable at initialization\n",
" self.is_async = self._check_client_async()\n",
" \n",
" def _check_client_async(self) -> bool:\n",
" \"\"\"Determine if the client is async-capable.\"\"\"\n",
" try:\n",
" # Check if this is an async client by checking for a coroutine method\n",
" if hasattr(self.client.chat.completions, 'create'):\n",
" return inspect.iscoroutinefunction(self.client.chat.completions.create)\n",
" return False\n",
" except (AttributeError, TypeError):\n",
" return False\n",
" \n",
" def _initialize_client(self, provider: str, client: t.Any) -> t.Any:\n",
" provider = provider.lower()\n",
" \n",
" if provider == \"openai\":\n",
" return instructor.from_openai(client)\n",
" elif provider == \"anthropic\":\n",
" return instructor.from_anthropic(client)\n",
" elif provider == \"cohere\":\n",
" return instructor.from_cohere(client)\n",
" elif provider == \"gemini\":\n",
" return instructor.from_gemini(client)\n",
" elif provider == \"litellm\":\n",
" return instructor.from_litellm(client)\n",
" else:\n",
" raise ValueError(f\"Unsupported provider: {provider}\")\n",
" \n",
" def _run_async_in_current_loop(self, coro):\n",
" \"\"\"Run an async coroutine in the current event loop if possible.\n",
" \n",
" This handles Jupyter environments correctly by using a separate thread\n",
" when a running event loop is detected.\n",
" \"\"\"\n",
" try:\n",
" # Try to get the current event loop\n",
" loop = asyncio.get_event_loop()\n",
" \n",
" if loop.is_running():\n",
" # If the loop is already running (like in Jupyter notebooks),\n",
" # we run the coroutine in a separate thread with its own event loop\n",
" result_container = {'result': None, 'exception': None}\n",
" \n",
" def run_in_thread():\n",
" # Create a new event loop for this thread\n",
" new_loop = asyncio.new_event_loop()\n",
" asyncio.set_event_loop(new_loop)\n",
" try:\n",
" # Run the coroutine in this thread's event loop\n",
" result_container['result'] = new_loop.run_until_complete(coro)\n",
" except Exception as e:\n",
" # Capture any exceptions to re-raise in the main thread\n",
" result_container['exception'] = e\n",
" finally:\n",
" # Clean up the event loop\n",
" new_loop.close()\n",
" \n",
" # Start the thread and wait for it to complete\n",
" thread = threading.Thread(target=run_in_thread)\n",
" thread.start()\n",
" thread.join()\n",
" \n",
" # Re-raise any exceptions that occurred in the thread\n",
" if result_container['exception']:\n",
" raise result_container['exception']\n",
" \n",
" return result_container['result']\n",
" else:\n",
" # Standard case - event loop exists but isn't running\n",
" return loop.run_until_complete(coro)\n",
" \n",
" except RuntimeError:\n",
" # If we get a runtime error about no event loop, create a new one\n",
" loop = asyncio.new_event_loop()\n",
" asyncio.set_event_loop(loop)\n",
" try:\n",
" return loop.run_until_complete(coro)\n",
" finally:\n",
" # Clean up\n",
" loop.close()\n",
" asyncio.set_event_loop(None)\n",
" \n",
" def generate(self, prompt: str, response_model: t.Type[T]) -> T:\n",
" \"\"\"Generate a response using the configured LLM.\n",
" \n",
" For async clients, this will run the async method in the appropriate event loop.\n",
" \"\"\"\n",
" messages = [{\"role\": \"user\", \"content\": prompt}]\n",
" \n",
" # If client is async, use the appropriate method to run it\n",
" if self.is_async:\n",
" return self._run_async_in_current_loop(\n",
" self.agenerate(prompt, response_model)\n",
" )\n",
" else:\n",
" # Regular sync client, just call the method directly\n",
" return self.client.chat.completions.create(\n",
" model=self.model,\n",
" messages=messages,\n",
" response_model=response_model,\n",
" **self.model_args,\n",
" )\n",
" \n",
" async def agenerate(self, prompt: str, response_model: t.Type[T]) -> T:\n",
" \"\"\"Asynchronously generate a response using the configured LLM.\"\"\"\n",
" messages = [{\"role\": \"user\", \"content\": prompt}]\n",
" \n",
" # If client is not async, raise a helpful error\n",
" if not self.is_async:\n",
" raise TypeError(\n",
" \"Cannot use agenerate() with a synchronous client. Use generate() instead.\"\n",
" )\n",
" \n",
" # Regular async client, call the method directly\n",
" return await self.client.chat.completions.create(\n",
" model=self.model,\n",
" messages=messages,\n",
" response_model=response_model,\n",
" **self.model_args,\n",
" )\n",
"\n",
"def ragas_llm(provider: str, model: str, client: t.Any, **model_args) -> RagasLLM:\n",
" return RagasLLM(provider=provider, client=client, model=model, **model_args)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example Usage"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| eval: false\n",
"\n",
"from openai import OpenAI\n",
"class Response(BaseModel):\n",
" response: str\n",
"\n",
"llm = ragas_llm(provider=\"openai\",model=\"gpt-4o\",client=OpenAI())\n",
"llm.generate(\"What is the capital of India?\",response_model=Response) #works fine\n",
"\n",
"try:\n",
" await llm.agenerate(\"What is the capital of India?\", response_model=Response)\n",
"except TypeError as e:\n",
" assert isinstance(e, TypeError)\n",
"#gives TypeError: object Response can't be used in 'await' expression\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Response(response='The capital of India is New Delhi.')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#| eval: false\n",
"\n",
"from openai import AsyncOpenAI\n",
"\n",
"llm = ragas_llm(provider=\"openai\",model=\"gpt-4o\",client=AsyncOpenAI())\n",
"await llm.agenerate(\"What is the capital of India?\",response_model=Response)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Response(response='The capital of India is New Delhi.')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#| eval: false\n",
"\n",
"from anthropic import Anthropic\n",
"\n",
"llm = ragas_llm(provider=\"anthropic\",model=\"claude-3-opus-20240229\",client=Anthropic(),max_tokens=1024)\n",
"llm.generate(\"What is the capital of India?\",response_model=Response)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
19 changes: 15 additions & 4 deletions nbs/metric/base.ipynb
@@ -43,14 +43,14 @@
"from pydantic import BaseModel\n",
"import typing as t\n",
"from ragas_annotator.metric import MetricResult\n",
"from ragas_annotator.metric import LLM\n",
"from ragas_annotator.llm import RagasLLM\n",
"\n",
"@dataclass\n",
"class Metric(ABC):\n",
" \"\"\"Base class for all metrics in the LLM evaluation library.\"\"\"\n",
" name: str\n",
" prompt: str\n",
" llm: LLM\n",
" llm: RagasLLM\n",
" _response_models: t.Dict[bool, t.Type[BaseModel]] = field(\n",
" default_factory=dict, init=False, repr=False\n",
" )\n",
@@ -114,6 +114,11 @@
"source": [
"#| eval: false\n",
"\n",
"from ragas_annotator.llm import ragas_llm\n",
"from openai import OpenAI\n",
"\n",
"llm = ragas_llm(provider=\"openai\",model=\"gpt-4o\",client=OpenAI())\n",
"\n",
"@dataclass\n",
"class CustomMetric(Metric):\n",
" values: t.List[str] = field(default_factory=lambda: [\"pass\", \"fail\"])\n",
@@ -131,12 +136,18 @@
" \n",
" return results[0] # Placeholder for ensemble logic\n",
"\n",
"my_metric = CustomMetric(name=\"example\", prompt=\"What is the result of {input}?\", llm=LLM())\n",
"my_metric = CustomMetric(name=\"example\", prompt=\"What is the result of {input}?\", llm=llm)\n",
"my_metric.score(input=\"test\")"
]
}
],
"metadata": {},
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
11 changes: 8 additions & 3 deletions nbs/metric/decorator.ipynb
@@ -30,6 +30,7 @@
"import asyncio\n",
@jjmachan (Member) commented on Mar 30, 2025 via ReviewNB:

Line #23.        def decorator_factory(llm:RagasLLM, prompt, name: t.Optional[str] = None, **metric_params):

    Type-annotate prompt.

Line #43.        class CustomMetric(metric_class):

    It's better to move this outside, right?

Line #108.        # Preserve metadata

    Use @wraps from functools.
"from dataclasses import dataclass\n",
"from ragas_annotator.metric import MetricResult\n",
"from ragas_annotator.llm import RagasLLM\n",
"\n",
"\n",
"\n",
@@ -44,7 +45,7 @@
" Returns:\n",
" A decorator factory function for the specified metric type\n",
" \"\"\"\n",
" def decorator_factory(llm, prompt, name: t.Optional[str] = None, **metric_params):\n",
" def decorator_factory(llm:RagasLLM, prompt, name: t.Optional[str] = None, **metric_params):\n",
" \"\"\"\n",
" Creates a decorator that wraps a function into a metric instance.\n",
" \n",
@@ -168,12 +169,16 @@
"\n",
"\n",
"from ragas_annotator.metric import DiscreteMetric\n",
"from ragas_annotator.metric.llm import LLM\n",
"from pydantic import BaseModel\n",
"\n",
"from ragas_annotator.llm import ragas_llm\n",
"from openai import OpenAI\n",
"\n",
"llm = ragas_llm(provider=\"openai\",model=\"gpt-4o\",client=OpenAI())\n",
"\n",
"discrete_metric = create_metric_decorator(DiscreteMetric)\n",
"\n",
"@discrete_metric(llm=LLM(),\n",
"@discrete_metric(llm=llm,\n",
" prompt=\"Evaluate if given answer is helpful\\n\\n{response}\",\n",
" name='new_metric',values=[\"low\",\"med\",\"high\"])\n",
"def my_metric(llm,prompt,**kwargs):\n",