diff --git a/.changeset/whispering-brave-potoo.md b/.changeset/whispering-brave-potoo.md
new file mode 100644
index 00000000..e7b37afd
--- /dev/null
+++ b/.changeset/whispering-brave-potoo.md
@@ -0,0 +1,5 @@
+---
+"stagehand": patch
+---
+
+Make litellm client async
diff --git a/stagehand/handlers/act_handler.py b/stagehand/handlers/act_handler.py
index 27ae6bd2..d6a7ccac 100644
--- a/stagehand/handlers/act_handler.py
+++ b/stagehand/handlers/act_handler.py
@@ -97,7 +97,7 @@ async def act(self, options: Union[ActOptions, ObserveResult]) -> ActResult:
             variables = options.get("variables", {})
             element_to_act_on.arguments = [
                 str(arg).replace(f"%{key}%", str(value))
-                for arg in (element_to_act_on.arguments or [])
+                for arg in element_to_act_on.arguments or []
                 for key, value in variables.items()
             ]

diff --git a/stagehand/handlers/extract_handler.py b/stagehand/handlers/extract_handler.py
index ec9b9fa5..8af621c9 100644
--- a/stagehand/handlers/extract_handler.py
+++ b/stagehand/handlers/extract_handler.py
@@ -105,7 +105,7 @@ async def extract(
             schema = transformed_schema = DefaultExtractSchema

         # Use inference to call the LLM
-        extraction_response = extract_inference(
+        extraction_response = await extract_inference(
             instruction=instruction,
             tree_elements=output_string,
             schema=transformed_schema,
diff --git a/stagehand/handlers/observe_handler.py b/stagehand/handlers/observe_handler.py
index f0f29181..5acce6d5 100644
--- a/stagehand/handlers/observe_handler.py
+++ b/stagehand/handlers/observe_handler.py
@@ -74,7 +74,7 @@ async def observe(
             iframes = tree.get("iframes", [])

         # use inference to call the llm
-        observation_response = observe_inference(
+        observation_response = await observe_inference(
             instruction=instruction,
             tree_elements=output_string,
             llm_client=self.stagehand.llm,
diff --git a/stagehand/llm/client.py b/stagehand/llm/client.py
index e9fbefe5..06dc9594 100644
--- a/stagehand/llm/client.py
+++ b/stagehand/llm/client.py
@@ -60,7 +60,7 @@ def __init__(
                 f"Set global litellm.api_base to {value}", category="llm"
             )

-    def create_response(
+    async def create_response(
         self,
         *,
         messages: list[dict[str, str]],
@@ -77,7 +77,7 @@ def create_response(
                 Overrides the default_model if provided.
             function_name: The name of the Stagehand function calling this method (ACT, OBSERVE, etc.)
                 Used for metrics tracking.
-            **kwargs: Additional parameters to pass directly to litellm.completion
+            **kwargs: Additional parameters to pass directly to litellm.acompletion
                 (e.g., temperature, max_tokens, stream=True, specific provider arguments).

         Returns:
@@ -87,7 +87,7 @@ def create_response(

         Raises:
             ValueError: If no model is specified (neither default nor in the call).
-            Exception: Propagates exceptions from litellm.completion.
+            Exception: Propagates exceptions from litellm.acompletion.
         """
         completion_model = model or self.default_model
         if not completion_model:
@@ -115,7 +115,7 @@ def create_response(
             filtered_params["temperature"] = 1

         self.logger.debug(
-            f"Calling litellm.completion with model={completion_model} and params: {filtered_params}",
+            f"Calling litellm.acompletion with model={completion_model} and params: {filtered_params}",
             category="llm",
         )

@@ -123,8 +123,8 @@ def create_response(
             # Start tracking inference time
             start_time = start_inference_timer()

-            # Use litellm's completion function
-            response = litellm.completion(**filtered_params)
+            # Use litellm's async completion function
+            response = await litellm.acompletion(**filtered_params)

             # Calculate inference time
             inference_time_ms = get_inference_time_ms(start_time)
@@ -136,6 +136,6 @@ def create_response(

             return response
         except Exception as e:
-            self.logger.error(f"Error calling litellm.completion: {e}", category="llm")
+            self.logger.error(f"Error calling litellm.acompletion: {e}", category="llm")
             # Consider more specific exception handling based on litellm errors
             raise
diff --git a/stagehand/llm/inference.py b/stagehand/llm/inference.py
index 24f0de91..b438883b 100644
--- a/stagehand/llm/inference.py
+++ b/stagehand/llm/inference.py
@@ -21,7 +21,7 @@


 # TODO: kwargs
-def observe(
+async def observe(
     instruction: str,
     tree_elements: str,
     llm_client: Any,
@@ -66,7 +66,7 @@ def observe(
     try:
         # Call the LLM
         logger.info("Calling LLM")
-        response = llm_client.create_response(
+        response = await llm_client.create_response(
             model=llm_client.default_model,
             messages=messages,
             response_format=ObserveInferenceSchema,
@@ -123,7 +123,7 @@ def observe(
     }


-def extract(
+async def extract(
     instruction: str,
     tree_elements: str,
     schema: Optional[Union[type[BaseModel], dict]] = None,
@@ -177,7 +177,7 @@ def extract(

     # Call the LLM with appropriate parameters
     try:
-        extract_response = llm_client.create_response(
+        extract_response = await llm_client.create_response(
             model=llm_client.default_model,
             messages=extract_messages,
             response_format=response_format,
@@ -227,7 +227,7 @@ def extract(
     # Call LLM for metadata
     try:
         metadata_start_time = time.time()
-        metadata_response = llm_client.create_response(
+        metadata_response = await llm_client.create_response(
             model=llm_client.default_model,
             messages=metadata_messages,
             response_format=metadata_schema,
diff --git a/tests/mocks/mock_llm.py b/tests/mocks/mock_llm.py
index 7c53275e..e1cd744f 100644
--- a/tests/mocks/mock_llm.py
+++ b/tests/mocks/mock_llm.py
@@ -258,7 +258,7 @@ def get_usage_stats(self) -> Dict[str, int]:
             "total_tokens": total_prompt_tokens + total_completion_tokens
         }

-    def create_response(
+    async def create_response(
         self,
         *,
         messages: list[dict[str, str]],
@@ -274,13 +274,13 @@ def create_response(
             # Fall back to content-based detection
             content = str(messages).lower()
             response_type = self._determine_response_type(content)
-        
+
         # Track the call
         self.call_count += 1
         self.last_messages = messages
         self.last_model = model or self.default_model
         self.last_kwargs = kwargs
-        
+
         # Store call in history
         call_info = {
             "messages": messages,
@@ -290,26 +290,26 @@ def create_response(
             "timestamp": asyncio.get_event_loop().time()
         }
         self.call_history.append(call_info)
-        
+
         # Simulate failure if configured
         if self.should_fail:
             raise Exception(self.failure_message)
-        
+
         # Check for custom responses first
         if response_type in self.custom_responses:
             response_data = self.custom_responses[response_type]
             if callable(response_data):
                 response_data = response_data(messages, **kwargs)
             return self._create_response(response_data, model=self.last_model)
-        
+
         # Use default response mapping
         response_generator = self.response_mapping.get(response_type, self._default_response)
         response_data = response_generator(messages, **kwargs)
-        
+
         response = self._create_response(response_data, model=self.last_model)
-        
+
         # Call metrics callback if set
         if self.metrics_callback:
             self.metrics_callback(response, 100, response_type) # 100ms mock inference time
-        
+
         return response
\ No newline at end of file
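Note: create_response is now a coroutine, so any code that calls the LLM client directly must await it (the observe/extract inference helpers above already do). A minimal sketch of an updated call site, not part of this diff; the `llm` and `messages` parameters are assumed to be a configured client instance and a prebuilt message list, and the "EXTRACT" label is only an illustrative function_name value:

    async def run_inference(llm, messages):
        # create_response now wraps litellm.acompletion, so it must be awaited
        response = await llm.create_response(
            model=llm.default_model,
            messages=messages,
            function_name="EXTRACT",  # illustrative label, used only for metrics tracking
        )
        # litellm returns an OpenAI-style ModelResponse
        return response.choices[0].message.content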