|
20 | 20 | "cell_type": "code", |
21 | 21 | "execution_count": null, |
22 | 22 | "metadata": {}, |
23 | | - "outputs": [ |
24 | | - { |
25 | | - "name": "stderr", |
26 | | - "output_type": "stream", |
27 | | - "text": [ |
28 | | - "/opt/homebrew/Caskroom/miniforge/base/envs/random/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", |
29 | | - " from .autonotebook import tqdm as notebook_tqdm\n" |
30 | | - ] |
31 | | - } |
32 | | - ], |
| 23 | + "outputs": [], |
33 | 24 | "source": [ |
34 | 25 | "#| export\n", |
35 | 26 | "\n", |
|
42 | 33 | "T = t.TypeVar('T', bound=BaseModel)\n", |
43 | 34 | "\n", |
44 | 35 | "class RagasLLM:\n", |
45 | | - " def __init__(self, provider: str, model:str, client: t.Any):\n", |
| 36 | + " def __init__(self, provider: str, model: str, client: t.Any, **model_args):\n",
46 | 37 | " self.provider = provider.lower()\n", |
47 | 38 | " self.model = model\n", |
| 39 | + " self.model_args = model_args\n",
48 | 40 | " self.client = self._initialize_client(provider, client)\n", |
49 | 41 | " # Check if client is async-capable at initialization\n", |
50 | 42 | " self.is_async = self._check_client_async()\n", |
|
98 | 90 | " loop.close()\n", |
99 | 91 | " asyncio.set_event_loop(None)\n", |
100 | 92 | " \n", |
101 | | - " def generate(self, prompt: str, response_model: t.Type[T], **kwargs) -> T:\n", |
| 93 | + " def generate(self, prompt: str, response_model: t.Type[T]) -> T:\n", |
102 | 94 | " \"\"\"Generate a response using the configured LLM.\n", |
103 | 95 | " \n", |
104 | 96 | " For async clients, this will run the async method in the appropriate event loop.\n", |
105 | 97 | " \"\"\"\n", |
106 | 98 | " messages = [{\"role\": \"user\", \"content\": prompt}]\n", |
107 | | - " if \"model\" not in kwargs and self.model:\n", |
108 | | - " kwargs[\"model\"] = self.model\n", |
109 | 99 | " \n", |
110 | 100 | " # If client is async, use the appropriate method to run it\n", |
111 | 101 | " if self.is_async:\n", |
112 | 102 | " return self._run_async_in_current_loop(\n", |
113 | | - " self.agenerate(prompt, response_model, **kwargs)\n", |
| 103 | + " self.agenerate(prompt, response_model)\n", |
114 | 104 | " )\n", |
115 | 105 | " else:\n", |
116 | 106 | " # Regular sync client, just call the method directly\n", |
117 | 107 | " return self.client.chat.completions.create(\n", |
| 108 | + " model=self.model,\n", |
118 | 109 | " messages=messages,\n", |
119 | 110 | " response_model=response_model,\n", |
120 | | - " **kwargs\n", |
| 111 | + " **self.model_args,\n", |
121 | 112 | " )\n", |
122 | 113 | " \n", |
123 | | - " async def agenerate(self, prompt: str, response_model: t.Type[T], **kwargs) -> T:\n", |
| 114 | + " async def agenerate(self, prompt: str, response_model: t.Type[T]) -> T:\n", |
124 | 115 | " \"\"\"Asynchronously generate a response using the configured LLM.\"\"\"\n", |
125 | 116 | " messages = [{\"role\": \"user\", \"content\": prompt}]\n", |
126 | | - " if \"model\" not in kwargs and self.model:\n", |
127 | | - " kwargs[\"model\"] = self.model\n", |
128 | 117 | " \n", |
129 | 118 | " # If client is not async, raise a helpful error\n", |
130 | 119 | " if not self.is_async:\n", |
|
134 | 123 | " \n", |
135 | 124 | " # Regular async client, call the method directly\n", |
136 | 125 | " return await self.client.chat.completions.create(\n", |
| 126 | + " model=self.model,\n", |
137 | 127 | " messages=messages,\n", |
138 | 128 | " response_model=response_model,\n", |
139 | | - " **kwargs\n", |
| 129 | + " **self.model_args,\n", |
140 | 130 | " )\n", |
141 | 131 | "\n", |
142 | | - "def ragas_llm(provider: str,model:str, client: t.Any,) -> RagasLLM:\n", |
143 | | - " return RagasLLM(provider=provider, client=client, model=model)" |
| 132 | + "def ragas_llm(provider: str, model: str, client: t.Any, **model_args) -> RagasLLM:\n",
| 133 | + " return RagasLLM(provider=provider, client=client, model=model, **model_args)"
144 | 134 | ] |
145 | 135 | }, |
146 | 136 | { |
|
201 | 191 | "cell_type": "code", |
202 | 192 | "execution_count": null, |
203 | 193 | "metadata": {}, |
204 | | - "outputs": [], |
205 | | - "source": [] |
| 194 | + "outputs": [ |
| 195 | + { |
| 196 | + "data": { |
| 197 | + "text/plain": [ |
| 198 | + "Response(response='The capital of India is New Delhi.')" |
| 199 | + ] |
| 200 | + }, |
| 201 | + "execution_count": null, |
| 202 | + "metadata": {}, |
| 203 | + "output_type": "execute_result" |
| 204 | + } |
| 205 | + ], |
| 206 | + "source": [ |
| 207 | + "#| eval: false\n", |
| 208 | + "\n", |
| 209 | + "from anthropic import Anthropic\n", |
| 210 | + "\n", |
| 211 | + "llm = ragas_llm(provider=\"anthropic\", model=\"claude-3-opus-20240229\", client=Anthropic(), max_tokens=1024)\n",
| 212 | + "llm.generate(\"What is the capital of India?\", response_model=Response)"
| 213 | + ] |
206 | 214 | } |
207 | 215 | ], |
208 | 216 | "metadata": { |
|
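The new example cell exercises only the sync `generate` path. For the async branch, a usage sketch along the following lines should work in a notebook cell, assuming `_initialize_client` also accepts OpenAI clients and that `Response` is the simple Pydantic model used in the example above (both assumptions; the model name and `temperature` are illustrative):

```python
#| eval: false
from openai import AsyncOpenAI
from pydantic import BaseModel

class Response(BaseModel):
    response: str  # assumed shape of the Response model used in the example cell

# Extra keyword args are captured as model_args and forwarded on every call.
llm = ragas_llm(provider="openai", model="gpt-4o-mini", client=AsyncOpenAI(), temperature=0)

# Top-level await works inside a Jupyter cell.
await llm.agenerate("What is the capital of India?", response_model=Response)
```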
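The sync `generate` path relies on a sync-over-async bridge. The body of `_run_async_in_current_loop` sits mostly outside this hunk, but the visible `loop.close()` / `asyncio.set_event_loop(None)` tail suggests a pattern along these lines (a minimal sketch, assuming no event loop is already running in the calling thread; the standalone helper name is illustrative):

```python
import asyncio
import typing as t

def run_async_in_current_loop(coro: t.Coroutine) -> t.Any:
    # Spin up a fresh event loop, run the coroutine to completion,
    # then tear the loop down so the thread is left clean.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(coro)
    finally:
        loop.close()
        asyncio.set_event_loop(None)
```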