feat: ragas llm #3
@@ -0,0 +1,217 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp llm.llm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LLM Interface for Ragas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/homebrew/Caskroom/miniforge/base/envs/random/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "#| export\n",
    "\n",
    "import typing as t\n",
    "import asyncio\n",
    "import inspect\n",
    "from pydantic import BaseModel\n",
    "import instructor\n",
    "\n",
    "T = t.TypeVar('T', bound=BaseModel)\n",
    "\n",
    "class RagasLLM:\n",
    "    def __init__(self, provider: str, model: str, client: t.Any):\n",
    "        self.provider = provider.lower()\n",
    "        self.model = model\n",
    "        self.client = self._initialize_client(provider, client)\n",
    "        # Check whether the client is async-capable at initialization\n",
    "        self.is_async = self._check_client_async()\n",
    "\n",
    "    def _check_client_async(self) -> bool:\n",
    "        \"\"\"Determine if the client is async-capable.\"\"\"\n",
    "        try:\n",
    "            # An async client exposes a coroutine `create` method\n",
    "            if hasattr(self.client.chat.completions, 'create'):\n",
    "                return inspect.iscoroutinefunction(self.client.chat.completions.create)\n",
    "            return False\n",
    "        except (AttributeError, TypeError):\n",
    "            return False\n",
    "\n",
    "    def _initialize_client(self, provider: str, client: t.Any) -> t.Any:\n",
    "        provider = provider.lower()\n",
    "\n",
    "        if provider == \"openai\":\n",
    "            return instructor.from_openai(client)\n",
    "        elif provider == \"anthropic\":\n",
    "            return instructor.from_anthropic(client)\n",
    "        elif provider == \"cohere\":\n",
    "            return instructor.from_cohere(client)\n",
    "        elif provider == \"gemini\":\n",
    "            return instructor.from_gemini(client)\n",
    "        elif provider == \"litellm\":\n",
    "            return instructor.from_litellm(client)\n",
    "        else:\n",
    "            raise ValueError(f\"Unsupported provider: {provider}\")\n",
    "\n",
    "    def _run_async_in_current_loop(self, coro):\n",
    "        \"\"\"Run an async coroutine in the current event loop if possible.\n",
    "\n",
    "        This handles Jupyter environments correctly by reusing the running loop.\n",
    "        \"\"\"\n",
    "        try:\n",
    "            loop = asyncio.get_event_loop()\n",
    "            if loop.is_running():\n",
    "                # A loop is already running (e.g. Jupyter); patch it so it\n",
    "                # can be re-entered synchronously\n",
    "                import nest_asyncio\n",
    "                nest_asyncio.apply()\n",
    "            # Covers both a running (patched) loop and an idle one\n",
    "            return loop.run_until_complete(coro)\n",
    "        except RuntimeError:\n",
    "            # No event loop in this thread: create one, run, and clean up\n",
    "            loop = asyncio.new_event_loop()\n",
    "            asyncio.set_event_loop(loop)\n",
    "            try:\n",
    "                return loop.run_until_complete(coro)\n",
    "            finally:\n",
    "                loop.close()\n",
    "                asyncio.set_event_loop(None)\n",
    "\n",
    "    def generate(self, prompt: str, response_model: t.Type[T], **kwargs) -> T:\n",
    "        \"\"\"Generate a response using the configured LLM.\n",
    "\n",
    "        For async clients, this runs the coroutine in the appropriate event loop.\n",
    "        \"\"\"\n",
    "        if \"model\" not in kwargs and self.model:\n",
    "            kwargs[\"model\"] = self.model\n",
    "\n",
    "        # Async client: drive agenerate() from this synchronous entry point\n",
    "        if self.is_async:\n",
    "            return self._run_async_in_current_loop(\n",
    "                self.agenerate(prompt, response_model, **kwargs)\n",
    "            )\n",
    "\n",
    "        # Regular sync client: call the method directly\n",
    "        messages = [{\"role\": \"user\", \"content\": prompt}]\n",
    "        return self.client.chat.completions.create(\n",
    "            messages=messages,\n",
    "            response_model=response_model,\n",
    "            **kwargs\n",
    "        )\n",
    "\n",
    "    async def agenerate(self, prompt: str, response_model: t.Type[T], **kwargs) -> T:\n",
    "        \"\"\"Asynchronously generate a response using the configured LLM.\"\"\"\n",
    "        # A sync client cannot be awaited; fail with a helpful error\n",
    "        if not self.is_async:\n",
    "            raise TypeError(\n",
    "                \"Cannot use agenerate() with a synchronous client. Use generate() instead.\"\n",
    "            )\n",
    "\n",
    "        if \"model\" not in kwargs and self.model:\n",
    "            kwargs[\"model\"] = self.model\n",
    "\n",
    "        messages = [{\"role\": \"user\", \"content\": prompt}]\n",
    "        return await self.client.chat.completions.create(\n",
    "            messages=messages,\n",
    "            response_model=response_model,\n",
    "            **kwargs\n",
    "        )\n",
    "\n",
    "def ragas_llm(provider: str, model: str, client: t.Any) -> RagasLLM:\n",
    "    return RagasLLM(provider=provider, client=client, model=model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example Usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "\n",
    "from openai import OpenAI\n",
    "\n",
    "class Response(BaseModel):\n",
    "    response: str\n",
    "\n",
    "llm = ragas_llm(provider=\"openai\", model=\"gpt-4o\", client=OpenAI())\n",
    "llm.generate(\"What is the capital of India?\", response_model=Response)  # works fine\n",
    "\n",
    "try:\n",
    "    await llm.agenerate(\"What is the capital of India?\", response_model=Response)\n",
    "except TypeError as e:\n",
    "    print(e)\n",
    "# raises TypeError: Cannot use agenerate() with a synchronous client. Use generate() instead.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Response(response='The capital of India is New Delhi.')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#| eval: false\n",
    "\n",
    "from openai import AsyncOpenAI\n",
    "\n",
    "llm = ragas_llm(provider=\"openai\", model=\"gpt-4o\", client=AsyncOpenAI())\n",
    "await llm.agenerate(\"What is the capital of India?\", response_model=Response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
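The sync/async branching above hinges entirely on inspect.iscoroutinefunction. Here is a minimal, self-contained sketch of that detection logic, runnable without any real provider client; the Dummy* classes are hypothetical stand-ins that only mimic the client.chat.completions.create attribute chain the wrapper inspects.

import inspect

class _SyncCompletions:
    def create(self, **kwargs):        # plain method -> sync client
        return "sync result"

class _AsyncCompletions:
    async def create(self, **kwargs):  # coroutine method -> async client
        return "async result"

class _Chat:
    def __init__(self, completions):
        self.completions = completions

class DummySyncClient:   # hypothetical stand-in for OpenAI()
    def __init__(self):
        self.chat = _Chat(_SyncCompletions())

class DummyAsyncClient:  # hypothetical stand-in for AsyncOpenAI()
    def __init__(self):
        self.chat = _Chat(_AsyncCompletions())

def is_async_client(client) -> bool:
    # Mirrors RagasLLM._check_client_async: an async-capable client
    # exposes a coroutine `create` method on chat.completions.
    return inspect.iscoroutinefunction(client.chat.completions.create)

assert is_async_client(DummySyncClient()) is False
assert is_async_client(DummyAsyncClient()) is True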
@@ -30,6 +30,7 @@
     "import asyncio\n",
Member (via ReviewNB): Line #23. def decorator_factory(llm: RagasLLM, prompt, name: t.Optional[str] = None, **metric_params): type annotate.

Member (via ReviewNB): Line #43. class CustomMetric(metric_class): it's better to move this outside, right?
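The full decorator body is not shown in this diff, so purely as an illustration of the second comment: a minimal sketch, assuming the current code defines class CustomMetric(metric_class) inside decorator_factory, of how that subclass creation could move out to module level (the helper name and the caching choice are hypothetical):

import functools

@functools.lru_cache(maxsize=None)
def make_custom_metric_cls(metric_class: type) -> type:
    # Built once per metric type and cached, instead of being redefined
    # on every decorator invocation inside decorator_factory.
    class CustomMetric(metric_class):  # type: ignore[misc, valid-type]
        pass
    CustomMetric.__name__ = f"Custom{metric_class.__name__}"
    return CustomMetric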
     "from dataclasses import dataclass\n",
     "from ragas_annotator.metric import MetricResult\n",
+    "from ragas_annotator.llm import RagasLLM\n",
     "\n",
     "\n",
     "\n",
@@ -44,7 +45,7 @@
     "    Returns:\n",
     "        A decorator factory function for the specified metric type\n",
     "    \"\"\"\n",
-    "    def decorator_factory(llm, prompt, name: t.Optional[str] = None, **metric_params):\n",
+    "    def decorator_factory(llm: RagasLLM, prompt, name: t.Optional[str] = None, **metric_params):\n",
     "        \"\"\"\n",
     "        Creates a decorator that wraps a function into a metric instance.\n",
@@ -168,12 +169,16 @@
     "\n",
     "\n",
     "from ragas_annotator.metric import DiscreteMetric\n",
-    "from ragas_annotator.metric.llm import LLM\n",
     "from pydantic import BaseModel\n",
     "\n",
+    "from ragas_annotator.llm import ragas_llm\n",
+    "from openai import OpenAI\n",
+    "\n",
+    "llm = ragas_llm(provider=\"openai\", model=\"gpt-4o\", client=OpenAI())\n",
+    "\n",
     "discrete_metric = create_metric_decorator(DiscreteMetric)\n",
     "\n",
-    "@discrete_metric(llm=LLM(),\n",
+    "@discrete_metric(llm=llm,\n",
     "                 prompt=\"Evaluate if given answer is helpful\\n\\n{response}\",\n",
     "                 name='new_metric', values=[\"low\", \"med\", \"high\"])\n",
     "def my_metric(llm, prompt, **kwargs):\n",
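For readability, here is the updated example from this hunk unescaped from the notebook JSON. create_metric_decorator is defined earlier in the same notebook, and the body of my_metric is truncated in the diff context, so it is elided here as well:

from ragas_annotator.metric import DiscreteMetric
from pydantic import BaseModel

from ragas_annotator.llm import ragas_llm
from openai import OpenAI

llm = ragas_llm(provider="openai", model="gpt-4o", client=OpenAI())

# create_metric_decorator comes from earlier in this notebook
discrete_metric = create_metric_decorator(DiscreteMetric)

@discrete_metric(llm=llm,
                 prompt="Evaluate if given answer is helpful\n\n{response}",
                 name='new_metric', values=["low", "med", "high"])
def my_metric(llm, prompt, **kwargs):
    ...  # body not shown in this diff hunk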