Merged
Changes from 3 commits
5 changes: 5 additions & 0 deletions .changeset/whispering-brave-potoo.md
@@ -0,0 +1,5 @@
+---
+"stagehand": patch
+---
+
+Make litellm client async
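Taken together, the diffs below turn the litellm client's `create_response` into a coroutine, so every call site has to be awaited from an async context. A minimal sketch of the resulting calling convention, using a stand-in client class (only the `create_response` name and its keyword-only `messages`/`model` parameters come from the diff; everything else is illustrative):

```python
import asyncio


class AsyncLLMClient:
    """Stand-in for stagehand's LLMClient after this PR (illustrative only)."""

    async def create_response(self, *, messages, model=None, **kwargs):
        # The real method delegates to litellm.acompletion (see client.py below);
        # here we just return a canned, completion-shaped payload.
        return {"choices": [{"message": {"content": "ok"}}]}


async def main():
    client = AsyncLLMClient()
    # After this change, create_response(...) returns a coroutine, so the
    # result is only available once it is awaited.
    response = await client.create_response(
        messages=[{"role": "user", "content": "hello"}]
    )
    print(response["choices"][0]["message"]["content"])


asyncio.run(main())
```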
2 changes: 1 addition & 1 deletion stagehand/handlers/extract_handler.py
@@ -105,7 +105,7 @@ async def extract(
schema = transformed_schema = DefaultExtractSchema

# Use inference to call the LLM
-extraction_response = extract_inference(
+extraction_response = await extract_inference(
instruction=instruction,
tree_elements=output_string,
schema=transformed_schema,
2 changes: 1 addition & 1 deletion stagehand/handlers/observe_handler.py
@@ -74,7 +74,7 @@ async def observe(
iframes = tree.get("iframes", [])

# use inference to call the llm
-observation_response = observe_inference(
+observation_response = await observe_inference(
instruction=instruction,
tree_elements=output_string,
llm_client=self.stagehand.llm,
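Both handler changes are the same one-word fix: `extract_inference` and `observe_inference` become coroutines in this PR, so the already-async handlers must `await` them. Forgetting the `await` would not fail at the call site; it would bind an un-run coroutine object, and the error only surfaces when the handler tries to read fields off it. A hedged sketch of that failure mode (names here are illustrative, not stagehand's actual API):

```python
import asyncio


async def observe_inference(**kwargs):
    # Stand-in for the now-async inference helper.
    return {"elements": []}


async def broken_handler():
    # Missing await: this binds a coroutine object, not the result dict.
    observation_response = observe_inference(
        instruction="find the login button", tree_elements="<tree>"
    )
    try:
        observation_response["elements"]  # TypeError: coroutine is not subscriptable
    except TypeError as exc:
        print(f"forgot to await: {exc}")
    finally:
        observation_response.close()  # avoid the "never awaited" warning


async def fixed_handler():
    observation_response = await observe_inference(
        instruction="find the login button", tree_elements="<tree>"
    )
    print(observation_response["elements"])


asyncio.run(broken_handler())
asyncio.run(fixed_handler())
```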
14 changes: 7 additions & 7 deletions stagehand/llm/client.py
@@ -60,7 +60,7 @@ def __init__(
f"Set global litellm.api_base to {value}", category="llm"
)

-def create_response(
+async def create_response(
self,
*,
messages: list[dict[str, str]],
@@ -77,7 +77,7 @@ def create_response(
Overrides the default_model if provided.
function_name: The name of the Stagehand function calling this method (ACT, OBSERVE, etc.)
Used for metrics tracking.
-**kwargs: Additional parameters to pass directly to litellm.completion
+**kwargs: Additional parameters to pass directly to litellm.acompletion
(e.g., temperature, max_tokens, stream=True, specific provider arguments).

Returns:
@@ -87,7 +87,7 @@ def create_response(

Raises:
ValueError: If no model is specified (neither default nor in the call).
-Exception: Propagates exceptions from litellm.completion.
+Exception: Propagates exceptions from litellm.acompletion.
"""
completion_model = model or self.default_model
if not completion_model:
@@ -115,16 +115,16 @@ def create_response(
filtered_params["temperature"] = 1

self.logger.debug(
-f"Calling litellm.completion with model={completion_model} and params: {filtered_params}",
+f"Calling litellm.acompletion with model={completion_model} and params: {filtered_params}",
category="llm",
)

try:
# Start tracking inference time
start_time = start_inference_timer()

-# Use litellm's completion function
-response = litellm.completion(**filtered_params)
+# Use litellm's async completion function
+response = await litellm.acompletion(**filtered_params)

# Calculate inference time
inference_time_ms = get_inference_time_ms(start_time)
@@ -136,6 +136,6 @@ def create_response(
return response

except Exception as e:
-self.logger.error(f"Error calling litellm.completion: {e}", category="llm")
+self.logger.error(f"Error calling litellm.acompletion: {e}", category="llm")
# Consider more specific exception handling based on litellm errors
raise
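For orientation, here is a stripped-down sketch of what the updated client method amounts to after this diff. It is not the real `LLMClient` (logging, metrics callbacks, parameter filtering, and the default model are reduced to placeholders), but it shows the core pattern the PR adopts: an `async def` method that awaits `litellm.acompletion` and times the call:

```python
import time

import litellm  # assumes the litellm package is installed


class MiniLLMClient:
    """Reduced sketch of the async client pattern in this PR (not the real class)."""

    def __init__(self, default_model: str = "gpt-4o-mini"):  # model name is illustrative
        self.default_model = default_model

    async def create_response(self, *, messages, model=None, **kwargs):
        completion_model = model or self.default_model
        if not completion_model:
            raise ValueError("No model specified for completion.")

        params = {"model": completion_model, "messages": messages, **kwargs}
        start_time = time.perf_counter()
        try:
            # The PR's key change: litellm.completion -> await litellm.acompletion
            response = await litellm.acompletion(**params)
        except Exception as exc:
            print(f"Error calling litellm.acompletion: {exc}")
            raise
        inference_time_ms = int((time.perf_counter() - start_time) * 1000)
        print(f"inference took {inference_time_ms} ms")
        return response
```

Callers then use it as `response = await client.create_response(messages=[...])` from inside an event loop, which is exactly what the handler and inference diffs do.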
10 changes: 5 additions & 5 deletions stagehand/llm/inference.py
@@ -21,7 +21,7 @@


# TODO: kwargs
-def observe(
+async def observe(
instruction: str,
tree_elements: str,
llm_client: Any,
@@ -66,7 +66,7 @@ def observe(
try:
# Call the LLM
logger.info("Calling LLM")
-response = llm_client.create_response(
+response = await llm_client.create_response(
model=llm_client.default_model,
messages=messages,
response_format=ObserveInferenceSchema,
@@ -123,7 +123,7 @@ def observe(
}


-def extract(
+async def extract(
instruction: str,
tree_elements: str,
schema: Optional[Union[type[BaseModel], dict]] = None,
@@ -177,7 +177,7 @@ def extract(

# Call the LLM with appropriate parameters
try:
-extract_response = llm_client.create_response(
+extract_response = await llm_client.create_response(
model=llm_client.default_model,
messages=extract_messages,
response_format=response_format,
@@ -227,7 +227,7 @@ def extract(
# Call LLM for metadata
try:
metadata_start_time = time.time()
-metadata_response = llm_client.create_response(
+metadata_response = await llm_client.create_response(
model=llm_client.default_model,
messages=metadata_messages,
response_format=metadata_schema,
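These two inference helpers were previously plain functions; making them `async def` is what forces the `await` added in the handler diffs above. A compressed sketch of the shape `extract` ends up with, assuming only what the diff shows (an async `llm_client.create_response` awaited once for extraction and once for metadata; prompts, schemas, and the return structure are placeholders):

```python
import time
from typing import Any


async def extract(instruction: str, tree_elements: str, llm_client: Any, **kwargs):
    """Illustrative skeleton only; the real function builds prompts and parses schemas."""
    extract_messages = [
        {"role": "user", "content": f"{instruction}\n\n{tree_elements}"}
    ]

    # First awaited LLM call: the extraction itself.
    extract_response = await llm_client.create_response(
        model=llm_client.default_model,
        messages=extract_messages,
    )

    # Second awaited LLM call: metadata about the extraction.
    metadata_start_time = time.time()
    metadata_response = await llm_client.create_response(
        model=llm_client.default_model,
        messages=[{"role": "user", "content": "Summarize the extraction."}],
    )
    metadata_time_ms = int((time.time() - metadata_start_time) * 1000)

    return {
        "extraction": extract_response,
        "metadata": metadata_response,
        "metadata_inference_time_ms": metadata_time_ms,
    }
```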
18 changes: 9 additions & 9 deletions tests/mocks/mock_llm.py
@@ -258,7 +258,7 @@ def get_usage_stats(self) -> Dict[str, int]:
"total_tokens": total_prompt_tokens + total_completion_tokens
}

-def create_response(
+async def create_response(
self,
*,
messages: list[dict[str, str]],
@@ -274,13 +274,13 @@ def create_response(
# Fall back to content-based detection
content = str(messages).lower()
response_type = self._determine_response_type(content)

# Track the call
self.call_count += 1
self.last_messages = messages
self.last_model = model or self.default_model
self.last_kwargs = kwargs

# Store call in history
call_info = {
"messages": messages,
@@ -290,26 +290,26 @@ def create_response(
"timestamp": asyncio.get_event_loop().time()
}
self.call_history.append(call_info)

# Simulate failure if configured
if self.should_fail:
raise Exception(self.failure_message)

# Check for custom responses first
if response_type in self.custom_responses:
response_data = self.custom_responses[response_type]
if callable(response_data):
response_data = response_data(messages, **kwargs)
return self._create_response(response_data, model=self.last_model)

# Use default response mapping
response_generator = self.response_mapping.get(response_type, self._default_response)
response_data = response_generator(messages, **kwargs)

response = self._create_response(response_data, model=self.last_model)

# Call metrics callback if set
if self.metrics_callback:
self.metrics_callback(response, 100, response_type)  # 100ms mock inference time

return response
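Since the mock now mirrors the async interface, any test that calls it directly also has to run inside an event loop. A hedged sketch of how such a test might look with pytest-asyncio (the import path and `MockLLMClient` name are assumptions for illustration; only the awaited `create_response` signature and the `call_count` attribute come from the diff):

```python
import pytest

from tests.mocks.mock_llm import MockLLMClient  # hypothetical import path/name


@pytest.mark.asyncio  # requires the pytest-asyncio plugin
async def test_create_response_is_awaitable():
    client = MockLLMClient()

    # The mock's create_response is now a coroutine function, so it must be awaited.
    response = await client.create_response(
        messages=[{"role": "user", "content": "observe the page"}]
    )

    assert client.call_count == 1
    assert response is not None
```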