Commit e414984

enable pass keyword args
1 parent 77322ec commit e414984

File tree

2 files changed (+43 -36 lines)

nbs/llm/llm.ipynb

Lines changed: 32 additions & 24 deletions
@@ -20,16 +20,7 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/opt/homebrew/Caskroom/miniforge/base/envs/random/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
-      " from .autonotebook import tqdm as notebook_tqdm\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "#| export\n",
     "\n",
@@ -42,9 +33,10 @@
    "T = t.TypeVar('T', bound=BaseModel)\n",
    "\n",
    "class RagasLLM:\n",
-   "    def __init__(self, provider: str, model:str, client: t.Any):\n",
+   "    def __init__(self, provider: str, model:str, client: t.Any, **model_args):\n",
    "        self.provider = provider.lower()\n",
    "        self.model = model\n",
+   "        self.model_args = model_args or {}\n",
    "        self.client = self._initialize_client(provider, client)\n",
    "        # Check if client is async-capable at initialization\n",
    "        self.is_async = self._check_client_async()\n",
@@ -98,33 +90,30 @@
    "            loop.close()\n",
    "            asyncio.set_event_loop(None)\n",
    "    \n",
-   "    def generate(self, prompt: str, response_model: t.Type[T], **kwargs) -> T:\n",
+   "    def generate(self, prompt: str, response_model: t.Type[T]) -> T:\n",
    "        \"\"\"Generate a response using the configured LLM.\n",
    "        \n",
    "        For async clients, this will run the async method in the appropriate event loop.\n",
    "        \"\"\"\n",
    "        messages = [{\"role\": \"user\", \"content\": prompt}]\n",
-   "        if \"model\" not in kwargs and self.model:\n",
-   "            kwargs[\"model\"] = self.model\n",
    "        \n",
    "        # If client is async, use the appropriate method to run it\n",
    "        if self.is_async:\n",
    "            return self._run_async_in_current_loop(\n",
-   "                self.agenerate(prompt, response_model, **kwargs)\n",
+   "                self.agenerate(prompt, response_model)\n",
    "            )\n",
    "        else:\n",
    "            # Regular sync client, just call the method directly\n",
    "            return self.client.chat.completions.create(\n",
+   "                model=self.model,\n",
    "                messages=messages,\n",
    "                response_model=response_model,\n",
-   "                **kwargs\n",
+   "                **self.model_args,\n",
    "            )\n",
    "    \n",
-   "    async def agenerate(self, prompt: str, response_model: t.Type[T], **kwargs) -> T:\n",
+   "    async def agenerate(self, prompt: str, response_model: t.Type[T]) -> T:\n",
    "        \"\"\"Asynchronously generate a response using the configured LLM.\"\"\"\n",
    "        messages = [{\"role\": \"user\", \"content\": prompt}]\n",
-   "        if \"model\" not in kwargs and self.model:\n",
-   "            kwargs[\"model\"] = self.model\n",
    "        \n",
    "        # If client is not async, raise a helpful error\n",
    "        if not self.is_async:\n",
@@ -134,13 +123,14 @@
    "        \n",
    "        # Regular async client, call the method directly\n",
    "        return await self.client.chat.completions.create(\n",
+   "            model=self.model,\n",
    "            messages=messages,\n",
    "            response_model=response_model,\n",
-   "            **kwargs\n",
+   "            **self.model_args,\n",
    "        )\n",
    "\n",
-   "def ragas_llm(provider: str,model:str, client: t.Any,) -> RagasLLM:\n",
-   "    return RagasLLM(provider=provider, client=client, model=model)"
+   "def ragas_llm(provider: str,model:str, client: t.Any, **model_args) -> RagasLLM:\n",
+   "    return RagasLLM(provider=provider, client=client, model=model, **model_args)"
    ]
   },
   {
@@ -201,8 +191,26 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
-   "outputs": [],
-   "source": []
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Response(response='The capital of India is New Delhi.')"
+      ]
+     },
+     "execution_count": null,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "#| eval: false\n",
+    "\n",
+    "from anthropic import Anthropic\n",
+    "\n",
+    "llm = ragas_llm(provider=\"anthropic\",model=\"claude-3-opus-20240229\",client=Anthropic(),max_tokens=1024)\n",
+    "llm.generate(\"What is the capital of India?\",response_model=Response)"
+   ]
  }
 ],
 "metadata": {

ragas_annotator/llm/llm.py

Lines changed: 11 additions & 12 deletions
@@ -13,9 +13,10 @@
 T = t.TypeVar('T', bound=BaseModel)
 
 class RagasLLM:
-    def __init__(self, provider: str, model:str, client: t.Any):
+    def __init__(self, provider: str, model:str, client: t.Any, **model_args):
         self.provider = provider.lower()
         self.model = model
+        self.model_args = model_args or {}
         self.client = self._initialize_client(provider, client)
         # Check if client is async-capable at initialization
         self.is_async = self._check_client_async()
@@ -69,33 +70,30 @@ def _run_async_in_current_loop(self, coro):
             loop.close()
             asyncio.set_event_loop(None)
 
-    def generate(self, prompt: str, response_model: t.Type[T], **kwargs) -> T:
+    def generate(self, prompt: str, response_model: t.Type[T]) -> T:
         """Generate a response using the configured LLM.
 
         For async clients, this will run the async method in the appropriate event loop.
         """
         messages = [{"role": "user", "content": prompt}]
-        if "model" not in kwargs and self.model:
-            kwargs["model"] = self.model
 
         # If client is async, use the appropriate method to run it
         if self.is_async:
             return self._run_async_in_current_loop(
-                self.agenerate(prompt, response_model, **kwargs)
+                self.agenerate(prompt, response_model)
             )
         else:
             # Regular sync client, just call the method directly
             return self.client.chat.completions.create(
+                model=self.model,
                 messages=messages,
                 response_model=response_model,
-                **kwargs
+                **self.model_args,
             )
 
-    async def agenerate(self, prompt: str, response_model: t.Type[T], **kwargs) -> T:
+    async def agenerate(self, prompt: str, response_model: t.Type[T]) -> T:
         """Asynchronously generate a response using the configured LLM."""
         messages = [{"role": "user", "content": prompt}]
-        if "model" not in kwargs and self.model:
-            kwargs["model"] = self.model
 
         # If client is not async, raise a helpful error
         if not self.is_async:
@@ -105,10 +103,11 @@ async def agenerate(self, prompt: str, response_model: t.Type[T], **kwargs) -> T
 
         # Regular async client, call the method directly
         return await self.client.chat.completions.create(
+            model=self.model,
             messages=messages,
             response_model=response_model,
-            **kwargs
+            **self.model_args,
         )
 
-def ragas_llm(provider: str,model:str, client: t.Any,) -> RagasLLM:
-    return RagasLLM(provider=provider, client=client, model=model)
+def ragas_llm(provider: str,model:str, client: t.Any, **model_args) -> RagasLLM:
+    return RagasLLM(provider=provider, client=client, model=model, **model_args)
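A consequence of moving the kwargs into the constructor is that each RagasLLM instance is now pinned to one set of generation parameters; varying max_tokens, for example, means building a second instance via ragas_llm. The sketch below exercises the async path under the new signatures. It is a hedged illustration, not from the repo: the import path is assumed from the package layout, and it assumes _initialize_client handles anthropic's async client the same way as the sync one shown in the notebook, which this diff does not confirm.

import asyncio

from anthropic import AsyncAnthropic
from pydantic import BaseModel

from ragas_annotator.llm.llm import ragas_llm  # assumed import path


class Response(BaseModel):
    response: str


async def main() -> None:
    # An async client makes _check_client_async() mark the instance as
    # async-capable, so agenerate() can be awaited directly; calling the
    # sync generate() would instead route through _run_async_in_current_loop.
    llm = ragas_llm(
        provider="anthropic",
        model="claude-3-opus-20240229",
        client=AsyncAnthropic(),
        max_tokens=1024,
    )
    result = await llm.agenerate("What is the capital of India?", response_model=Response)
    print(result.response)


asyncio.run(main())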
