5 changes: 5 additions & 0 deletions .changeset/whispering-brave-potoo.md
@@ -0,0 +1,5 @@
---
"stagehand": patch
---

Make litellm client async
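Since this makes create_response a coroutine, every call site changes shape. A minimal before/after sketch, assuming a configured client instance named llm (the variable name is illustrative):

# before: synchronous call
# response = llm.create_response(messages=messages, function_name="ACT")

# after: create_response is a coroutine and must be awaited from async code
response = await llm.create_response(messages=messages, function_name="ACT")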
2 changes: 1 addition & 1 deletion stagehand/handlers/act_handler.py
@@ -97,7 +97,7 @@ async def act(self, options: Union[ActOptions, ObserveResult]) -> ActResult:
variables = options.get("variables", {})
element_to_act_on.arguments = [
str(arg).replace(f"%{key}%", str(value))
for arg in (element_to_act_on.arguments or [])
for arg in element_to_act_on.arguments or []
for key, value in variables.items()
]

2 changes: 1 addition & 1 deletion stagehand/handlers/extract_handler.py
@@ -105,7 +105,7 @@ async def extract(
schema = transformed_schema = DefaultExtractSchema

# Use inference to call the LLM
extraction_response = extract_inference(
extraction_response = await extract_inference(
instruction=instruction,
tree_elements=output_string,
schema=transformed_schema,
2 changes: 1 addition & 1 deletion stagehand/handlers/observe_handler.py
@@ -74,7 +74,7 @@ async def observe(
iframes = tree.get("iframes", [])

# use inference to call the llm
observation_response = observe_inference(
observation_response = await observe_inference(
instruction=instruction,
tree_elements=output_string,
llm_client=self.stagehand.llm,
14 changes: 7 additions & 7 deletions stagehand/llm/client.py
@@ -60,7 +60,7 @@ def __init__(
f"Set global litellm.api_base to {value}", category="llm"
)

def create_response(
async def create_response(
self,
*,
messages: list[dict[str, str]],
@@ -77,7 +77,7 @@ def create_response(
Overrides the default_model if provided.
function_name: The name of the Stagehand function calling this method (ACT, OBSERVE, etc.)
Used for metrics tracking.
**kwargs: Additional parameters to pass directly to litellm.completion
**kwargs: Additional parameters to pass directly to litellm.acompletion
(e.g., temperature, max_tokens, stream=True, specific provider arguments).

Returns:
@@ -87,7 +87,7 @@ def create_response(

Raises:
ValueError: If no model is specified (neither default nor in the call).
Exception: Propagates exceptions from litellm.completion.
Exception: Propagates exceptions from litellm.acompletion.
"""
completion_model = model or self.default_model
if not completion_model:
@@ -115,16 +115,16 @@ def create_response(
filtered_params["temperature"] = 1

self.logger.debug(
f"Calling litellm.completion with model={completion_model} and params: {filtered_params}",
f"Calling litellm.acompletion with model={completion_model} and params: {filtered_params}",
category="llm",
)

try:
# Start tracking inference time
start_time = start_inference_timer()

# Use litellm's completion function
response = litellm.completion(**filtered_params)
# Use litellm's async completion function
response = await litellm.acompletion(**filtered_params)

# Calculate inference time
inference_time_ms = get_inference_time_ms(start_time)
@@ -136,6 +136,6 @@ def create_response(
return response

except Exception as e:
self.logger.error(f"Error calling litellm.completion: {e}", category="llm")
self.logger.error(f"Error calling litellm.acompletion: {e}", category="llm")
# Consider more specific exception handling based on litellm errors
raise
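A usage sketch for the now-async client, assuming an already-constructed client instance from stagehand/llm/client.py with a default model set (construction details are outside this diff; the response shape is the usual OpenAI-style object returned by litellm):

async def ask(llm_client, question: str) -> str:
    # create_response now wraps litellm.acompletion, so it must be awaited
    response = await llm_client.create_response(
        messages=[{"role": "user", "content": question}],
        function_name="EXTRACT",  # used only for metrics tracking
    )
    # litellm returns an OpenAI-style response object
    return response.choices[0].message.content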
10 changes: 5 additions & 5 deletions stagehand/llm/inference.py
@@ -21,7 +21,7 @@


# TODO: kwargs
def observe(
async def observe(
instruction: str,
tree_elements: str,
llm_client: Any,
@@ -66,7 +66,7 @@ def observe(
try:
# Call the LLM
logger.info("Calling LLM")
response = llm_client.create_response(
response = await llm_client.create_response(
model=llm_client.default_model,
messages=messages,
response_format=ObserveInferenceSchema,
@@ -123,7 +123,7 @@ def observe(
}


def extract(
async def extract(
instruction: str,
tree_elements: str,
schema: Optional[Union[type[BaseModel], dict]] = None,
@@ -177,7 +177,7 @@ def extract(

# Call the LLM with appropriate parameters
try:
extract_response = llm_client.create_response(
extract_response = await llm_client.create_response(
model=llm_client.default_model,
messages=extract_messages,
response_format=response_format,
@@ -227,7 +227,7 @@ def extract(
# Call LLM for metadata
try:
metadata_start_time = time.time()
metadata_response = llm_client.create_response(
metadata_response = await llm_client.create_response(
model=llm_client.default_model,
messages=metadata_messages,
response_format=metadata_schema,
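The handler hunks above show the corresponding call sites; the pattern is simply to await the inference helpers from the already-async handlers. A sketch under that assumption, with any parameters not visible in this diff (loggers, metrics callbacks, user instructions) omitted and therefore possibly still required in practice:

from stagehand.llm import inference

async def observe_page(instruction: str, output_string: str, llm_client):
    # observe() and extract() are now coroutines
    observation_response = await inference.observe(
        instruction=instruction,
        tree_elements=output_string,
        llm_client=llm_client,
    )
    return observation_response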
18 changes: 9 additions & 9 deletions tests/mocks/mock_llm.py
@@ -258,7 +258,7 @@ def get_usage_stats(self) -> Dict[str, int]:
"total_tokens": total_prompt_tokens + total_completion_tokens
}

def create_response(
async def create_response(
self,
*,
messages: list[dict[str, str]],
@@ -274,13 +274,13 @@ def create_response(
# Fall back to content-based detection
content = str(messages).lower()
response_type = self._determine_response_type(content)

# Track the call
self.call_count += 1
self.last_messages = messages
self.last_model = model or self.default_model
self.last_kwargs = kwargs

# Store call in history
call_info = {
"messages": messages,
@@ -290,26 +290,26 @@ def create_response(
"timestamp": asyncio.get_event_loop().time()
}
self.call_history.append(call_info)

# Simulate failure if configured
if self.should_fail:
raise Exception(self.failure_message)

# Check for custom responses first
if response_type in self.custom_responses:
response_data = self.custom_responses[response_type]
if callable(response_data):
response_data = response_data(messages, **kwargs)
return self._create_response(response_data, model=self.last_model)

# Use default response mapping
response_generator = self.response_mapping.get(response_type, self._default_response)
response_data = response_generator(messages, **kwargs)

response = self._create_response(response_data, model=self.last_model)

# Call metrics callback if set
if self.metrics_callback:
self.metrics_callback(response, 100, response_type) # 100ms mock inference time

return response
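With the mock now matching the async client, tests exercise it inside an event loop. A sketch of such a test, assuming the mock class is exposed as MockLLMClient and pytest-asyncio is installed (neither is confirmed by this diff):

import pytest
from tests.mocks.mock_llm import MockLLMClient  # class name assumed

@pytest.mark.asyncio
async def test_mock_create_response_is_awaitable():
    mock = MockLLMClient()
    response = await mock.create_response(
        messages=[{"role": "user", "content": "click the login button"}],
    )
    # the mock records each call and returns a canned response object
    assert mock.call_count == 1
    assert response is not None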