
Commit 3bcdd05

make litellm async (#205)
* make litellm async
* add changeset
* remove extra file
* format
1 parent ad95605 commit 3bcdd05

File tree: 7 files changed, +29 -24 lines changed

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+---
+"stagehand": patch
+---
+
+Make litellm client async

stagehand/handlers/act_handler.py

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ async def act(self, options: Union[ActOptions, ObserveResult]) -> ActResult:
        variables = options.get("variables", {})
        element_to_act_on.arguments = [
            str(arg).replace(f"%{key}%", str(value))
-           for arg in (element_to_act_on.arguments or [])
+           for arg in element_to_act_on.arguments or []
            for key, value in variables.items()
        ]
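
The change above only removes a pair of redundant parentheses; behavior is unchanged, because in a comprehension the iterable expression binds as a whole. A minimal sketch (variable names are illustrative, not from this repo):

# Both comprehensions iterate over (args or []), so a None value yields [].
args = None
assert [a for a in (args or [])] == [a for a in args or []] == []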

stagehand/handlers/extract_handler.py

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ async def extract(
        schema = transformed_schema = DefaultExtractSchema

        # Use inference to call the LLM
-       extraction_response = extract_inference(
+       extraction_response = await extract_inference(
            instruction=instruction,
            tree_elements=output_string,
            schema=transformed_schema,

stagehand/handlers/observe_handler.py

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ async def observe(
        iframes = tree.get("iframes", [])

        # use inference to call the llm
-       observation_response = observe_inference(
+       observation_response = await observe_inference(
            instruction=instruction,
            tree_elements=output_string,
            llm_client=self.stagehand.llm,

stagehand/llm/client.py

Lines changed: 7 additions & 7 deletions
@@ -60,7 +60,7 @@ def __init__(
                f"Set global litellm.api_base to {value}", category="llm"
            )

-    def create_response(
+    async def create_response(
        self,
        *,
        messages: list[dict[str, str]],
@@ -77,7 +77,7 @@ def create_response(
                Overrides the default_model if provided.
            function_name: The name of the Stagehand function calling this method (ACT, OBSERVE, etc.)
                Used for metrics tracking.
-            **kwargs: Additional parameters to pass directly to litellm.completion
+            **kwargs: Additional parameters to pass directly to litellm.acompletion
                (e.g., temperature, max_tokens, stream=True, specific provider arguments).

        Returns:
@@ -87,7 +87,7 @@ def create_response(

        Raises:
            ValueError: If no model is specified (neither default nor in the call).
-            Exception: Propagates exceptions from litellm.completion.
+            Exception: Propagates exceptions from litellm.acompletion.
        """
        completion_model = model or self.default_model
        if not completion_model:
@@ -115,16 +115,16 @@ def create_response(
            filtered_params["temperature"] = 1

        self.logger.debug(
-            f"Calling litellm.completion with model={completion_model} and params: {filtered_params}",
+            f"Calling litellm.acompletion with model={completion_model} and params: {filtered_params}",
            category="llm",
        )

        try:
            # Start tracking inference time
            start_time = start_inference_timer()

-            # Use litellm's completion function
-            response = litellm.completion(**filtered_params)
+            # Use litellm's async completion function
+            response = await litellm.acompletion(**filtered_params)

            # Calculate inference time
            inference_time_ms = get_inference_time_ms(start_time)
@@ -136,6 +136,6 @@ def create_response(
            return response

        except Exception as e:
-            self.logger.error(f"Error calling litellm.completion: {e}", category="llm")
+            self.logger.error(f"Error calling litellm.acompletion: {e}", category="llm")
            # Consider more specific exception handling based on litellm errors
            raise
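
The core of this change is swapping litellm's synchronous entry point for its async one, which must be awaited from inside a coroutine. A minimal standalone sketch of that pattern (the model name and message are placeholders, not taken from this commit):

import asyncio
import litellm

async def demo() -> str:
    # Before: response = litellm.completion(model=..., messages=...)
    # After: the async variant must be awaited inside an async function.
    response = await litellm.acompletion(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "ping"}],
    )
    return response.choices[0].message.content

# asyncio.run(demo())  # requires a configured API key for the chosen provider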

stagehand/llm/inference.py

Lines changed: 5 additions & 5 deletions
@@ -21,7 +21,7 @@


 # TODO: kwargs
-def observe(
+async def observe(
    instruction: str,
    tree_elements: str,
    llm_client: Any,
@@ -66,7 +66,7 @@ def observe(
    try:
        # Call the LLM
        logger.info("Calling LLM")
-        response = llm_client.create_response(
+        response = await llm_client.create_response(
            model=llm_client.default_model,
            messages=messages,
            response_format=ObserveInferenceSchema,
@@ -123,7 +123,7 @@ def observe(
    }


-def extract(
+async def extract(
    instruction: str,
    tree_elements: str,
    schema: Optional[Union[type[BaseModel], dict]] = None,
@@ -177,7 +177,7 @@ def extract(

    # Call the LLM with appropriate parameters
    try:
-        extract_response = llm_client.create_response(
+        extract_response = await llm_client.create_response(
            model=llm_client.default_model,
            messages=extract_messages,
            response_format=response_format,
@@ -227,7 +227,7 @@ def extract(
    # Call LLM for metadata
    try:
        metadata_start_time = time.time()
-        metadata_response = llm_client.create_response(
+        metadata_response = await llm_client.create_response(
            model=llm_client.default_model,
            messages=metadata_messages,
            response_format=metadata_schema,
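
Because observe and extract are now coroutines, every call site has to await them (as the handler changes above do); calling them without await only creates a coroutine object. A generic illustration of that pitfall (not code from this repo):

import asyncio

async def get_value() -> int:
    return 42

coro = get_value()          # nothing has run yet; this is a coroutine object
result = asyncio.run(coro)  # the coroutine executes only when awaited/run
assert result == 42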

tests/mocks/mock_llm.py

Lines changed: 9 additions & 9 deletions
@@ -258,7 +258,7 @@ def get_usage_stats(self) -> Dict[str, int]:
            "total_tokens": total_prompt_tokens + total_completion_tokens
        }

-    def create_response(
+    async def create_response(
        self,
        *,
        messages: list[dict[str, str]],
@@ -274,13 +274,13 @@ def create_response(
        # Fall back to content-based detection
        content = str(messages).lower()
        response_type = self._determine_response_type(content)
-
+
        # Track the call
        self.call_count += 1
        self.last_messages = messages
        self.last_model = model or self.default_model
        self.last_kwargs = kwargs
-
+
        # Store call in history
        call_info = {
            "messages": messages,
@@ -290,26 +290,26 @@ def create_response(
            "timestamp": asyncio.get_event_loop().time()
        }
        self.call_history.append(call_info)
-
+
        # Simulate failure if configured
        if self.should_fail:
            raise Exception(self.failure_message)
-
+
        # Check for custom responses first
        if response_type in self.custom_responses:
            response_data = self.custom_responses[response_type]
            if callable(response_data):
                response_data = response_data(messages, **kwargs)
            return self._create_response(response_data, model=self.last_model)
-
+
        # Use default response mapping
        response_generator = self.response_mapping.get(response_type, self._default_response)
        response_data = response_generator(messages, **kwargs)
-
+
        response = self._create_response(response_data, model=self.last_model)
-
+
        # Call metrics callback if set
        if self.metrics_callback:
            self.metrics_callback(response, 100, response_type)  # 100ms mock inference time
-
+
        return response
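
Tests that exercise the mock now need an async test function as well. A hypothetical sketch assuming pytest-asyncio is available and that the mock class is exposed as MockLLMClient (its actual name and constructor are not shown in this diff):

import pytest
from tests.mocks.mock_llm import MockLLMClient  # assumed import path and class name

@pytest.mark.asyncio
async def test_mock_llm_is_awaited():
    client = MockLLMClient()  # assumed constructor, for illustration only
    response = await client.create_response(
        messages=[{"role": "user", "content": "click the login button"}],
    )
    assert client.call_count == 1  # call tracking shown in the diff above
    assert response is not None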
