diff --git a/pyproject.toml b/pyproject.toml index 4280abac..e8e6e5ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ classifiers = [ ] dependencies = [ "openai>=2.6.1", + "jmespath>=1.0.1,<2", "pydantic>=2.11.4", "python-dotenv>=1.2.1", "requests>=2.31.0", @@ -46,7 +47,6 @@ dependencies = [ "beautifulsoup4>=4.14.2", "pylint>=3.0.0", "pytest>=8.3.5", - "jmespath>=1.0.1,<2", ] [project.optional-dependencies] @@ -60,7 +60,7 @@ web-hacker-discover = "web_hacker.scripts.discover_routine:main" web-hacker-execute = "web_hacker.scripts.execute_routine:main" [project.urls] -Homepage = "https://www.vectorly.app" +Homepage = "https://vectorly.app" Documentation = "https://github.com/VectorlyApp/web-hacker#readme" Repository = "https://github.com/VectorlyApp/web-hacker" Issues = "https://github.com/VectorlyApp/web-hacker/issues" diff --git a/scripts/run_guide_agent.py b/scripts/run_guide_agent.py new file mode 100755 index 00000000..b67293df --- /dev/null +++ b/scripts/run_guide_agent.py @@ -0,0 +1,373 @@ +#!/usr/bin/env python3 +""" +scripts/run_guide_agent.py + +Interactive terminal interface for the Guide Agent. +Guides users through creating web automation routines. +""" + +import argparse +import json +import sys +import textwrap +from typing import Any + +from web_hacker.agents.guide_agent import GuideAgent +from web_hacker.data_models.llms.vendors import OpenAIModel +from web_hacker.data_models.llms.interaction import ( + ChatMessageType, + EmittedChatMessage, + PendingToolInvocation, + ToolInvocationStatus, +) + + +# ANSI color codes +class Colors: + """ANSI escape codes for terminal colors.""" + + RESET = "\033[0m" + BOLD = "\033[1m" + DIM = "\033[2m" + ITALIC = "\033[3m" + UNDERLINE = "\033[4m" + + # Foreground colors + BLACK = "\033[30m" + RED = "\033[31m" + GREEN = "\033[32m" + YELLOW = "\033[33m" + BLUE = "\033[34m" + MAGENTA = "\033[35m" + CYAN = "\033[36m" + WHITE = "\033[37m" + + # Bright foreground + BRIGHT_BLACK = "\033[90m" + BRIGHT_RED = "\033[91m" + BRIGHT_GREEN = "\033[92m" + BRIGHT_YELLOW = "\033[93m" + BRIGHT_BLUE = "\033[94m" + BRIGHT_MAGENTA = "\033[95m" + BRIGHT_CYAN = "\033[96m" + BRIGHT_WHITE = "\033[97m" + + # Background colors + BG_BLACK = "\033[40m" + BG_RED = "\033[41m" + BG_GREEN = "\033[42m" + BG_YELLOW = "\033[43m" + BG_BLUE = "\033[44m" + BG_MAGENTA = "\033[45m" + BG_CYAN = "\033[46m" + BG_WHITE = "\033[47m" + + +def colorize(text: str, *codes: str) -> str: + """Apply ANSI color codes to text.""" + return "".join(codes) + text + Colors.RESET + + +def print_wrapped(text: str, indent: str = " ", width: int = 80) -> None: + """Print text with word wrapping and indentation.""" + lines = text.split("\n") + for line in lines: + if line.strip(): + wrapped = textwrap.fill(line, width=width, initial_indent=indent, subsequent_indent=indent) + print(wrapped) + else: + print() + + +class TerminalGuideChat: + """Interactive terminal chat interface for the Guide Agent.""" + + BANNER = r""" + ╔════════════════════════════════════════════════════════════════════╗ + ║ ║ + ║ ██╗ ██╗███████╗ ██████╗████████╗ ██████╗ ██████╗ ██╗ ██╗ ██╗ ║ + ║ ██║ ██║██╔════╝██╔════╝╚══██╔══╝██╔═══██╗██╔══██╗██║ ╚██╗ ██╔╝ ║ + ║ ██║ ██║█████╗ ██║ ██║ ██║ ██║██████╔╝██║ ╚████╔╝ ║ + ║ ╚██╗ ██╔╝██╔══╝ ██║ ██║ ██║ ██║██╔══██╗██║ ╚██╔╝ ║ + ║ ╚████╔╝ ███████╗╚██████╗ ██║ ╚██████╔╝██║ ██║███████╗██║ ║ + ║ ╚═══╝ ╚══════╝ ╚═════╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝╚══════╝╚═╝ ║ + ║ ║ + ║ Guide Agent Terminal ║ + ║ ║ + ╚════════════════════════════════════════════════════════════════════╝ + """ + + 
WELCOME_MESSAGE = """ + Welcome! I'll help you create a web automation routine from your + CDP (Chrome DevTools Protocol) captures. + + I'll analyze your network transactions to identify relevant API + endpoints, required cookies, headers, and request patterns that + can be turned into a reusable routine. + + Commands: + • Type your message and press Enter to chat + • Type 'quit' or 'exit' to leave + • Type 'reset' to start a new conversation + + Links: + • Docs: https://vectorly.app/docs + • Console: https://console.vectorly.app + """ + + def __init__(self, llm_model: OpenAIModel | None = None) -> None: + """Initialize the terminal chat interface.""" + self._pending_invocation: PendingToolInvocation | None = None + self._streaming_started: bool = False + self._agent = GuideAgent( + emit_message_callable=self._handle_message, + stream_chunk_callable=self._handle_stream_chunk, + llm_model=llm_model if llm_model else OpenAIModel.GPT_5_MINI, + ) + + def _handle_stream_chunk(self, chunk: str) -> None: + """ + Handle a streaming text chunk from the LLM. + + Args: + chunk: A text chunk from the streaming response. + """ + if not self._streaming_started: + # Print the header before the first chunk + print() + print(colorize(" Assistant", Colors.BOLD, Colors.CYAN) + colorize(":", Colors.DIM)) + print() + print(" ", end="", flush=True) + self._streaming_started = True + + # Print chunk without newline, flush immediately + print(chunk, end="", flush=True) + + def _handle_message(self, message: EmittedChatMessage) -> None: + """ + Handle messages emitted by the Guide Agent. + + Args: + message: The emitted message from the agent. + """ + if message.type == ChatMessageType.CHAT_RESPONSE: + # If we were streaming, just finish with newlines (content already printed) + if self._streaming_started: + print() # End the streamed line + print() # Add spacing + self._streaming_started = False + else: + self._print_assistant_message(message.content or "") + + elif message.type == ChatMessageType.TOOL_INVOCATION_REQUEST: + if message.tool_invocation: + self._pending_invocation = message.tool_invocation + self._print_tool_request(message.tool_invocation) + + elif message.type == ChatMessageType.TOOL_INVOCATION_RESULT: + if message.tool_invocation: + self._print_tool_result( + message.tool_invocation, + message.tool_result, + message.error, + ) + + elif message.type == ChatMessageType.ERROR: + self._print_error(message.error or "Unknown error") + + def _print_assistant_message(self, content: str) -> None: + """Print an assistant response.""" + print() + print(colorize(" Assistant", Colors.BOLD, Colors.CYAN) + colorize(":", Colors.DIM)) + print() + print_wrapped(content, indent=" ") + print() + + def _print_tool_request(self, invocation: PendingToolInvocation) -> None: + """Print a tool invocation request with formatted arguments.""" + print() + print(colorize(" ┌─────────────────────────────────────────────────────────────────┐", Colors.YELLOW)) + print(colorize(" │", Colors.YELLOW) + colorize(" TOOL INVOCATION REQUEST", Colors.BOLD, Colors.YELLOW) + colorize(" │", Colors.YELLOW)) + print(colorize(" ├─────────────────────────────────────────────────────────────────┤", Colors.YELLOW)) + print(colorize(" │", Colors.YELLOW)) + + # Tool name + print(colorize(" │ ", Colors.YELLOW) + colorize("Tool: ", Colors.DIM) + colorize(invocation.tool_name, Colors.BRIGHT_WHITE, Colors.BOLD)) + + # Arguments + print(colorize(" │", Colors.YELLOW)) + print(colorize(" │ ", Colors.YELLOW) + colorize("Arguments:", Colors.DIM)) + + 
args_json = json.dumps(invocation.tool_arguments, indent=4) + for line in args_json.split("\n"): + print(colorize(" │ ", Colors.YELLOW) + colorize(line, Colors.WHITE)) + + print(colorize(" │", Colors.YELLOW)) + print(colorize(" └─────────────────────────────────────────────────────────────────┘", Colors.YELLOW)) + print() + print(colorize(" Do you want to proceed? ", Colors.BRIGHT_YELLOW) + colorize("[y/n]", Colors.DIM) + ": ", end="") + + def _print_tool_result( + self, + invocation: PendingToolInvocation, + result: dict[str, Any] | None, + error: str | None, + ) -> None: + """Print a tool invocation result.""" + print() + + if invocation.status == ToolInvocationStatus.DENIED: + print(colorize(" ✗ Tool invocation denied", Colors.YELLOW)) + + elif invocation.status == ToolInvocationStatus.EXECUTED: + print(colorize(" ✓ Tool executed successfully", Colors.GREEN, Colors.BOLD)) + if result: + print() + print(colorize(" Result:", Colors.DIM)) + result_json = json.dumps(result, indent=4) + for line in result_json.split("\n"): + print(colorize(" " + line, Colors.GREEN)) + + elif invocation.status == ToolInvocationStatus.FAILED: + print(colorize(" ✗ Tool execution failed", Colors.RED, Colors.BOLD)) + if error: + print(colorize(f" Error: {error}", Colors.RED)) + + print() + + def _print_error(self, error: str) -> None: + """Print an error message.""" + print() + print(colorize(" ⚠ Error: ", Colors.RED, Colors.BOLD) + colorize(error, Colors.RED)) + print() + + def _print_user_prompt(self) -> None: + """Print the user input prompt.""" + print(colorize(" You", Colors.BOLD, Colors.GREEN) + colorize(": ", Colors.DIM), end="") + + def _handle_tool_confirmation(self, user_input: str) -> bool: + """ + Handle yes/no confirmation for pending tool invocation. + + Args: + user_input: The user's input. + + Returns: + True if the confirmation was handled, False otherwise. + """ + if not self._pending_invocation: + return False + + normalized = user_input.strip().lower() + + if normalized in ("y", "yes"): + invocation_id = self._pending_invocation.invocation_id + self._pending_invocation = None + self._agent.confirm_tool_invocation(invocation_id) + return True + + elif normalized in ("n", "no"): + invocation_id = self._pending_invocation.invocation_id + self._pending_invocation = None + self._agent.deny_tool_invocation(invocation_id, reason="User declined") + return True + + else: + print(colorize(" Please enter 'y' or 'n': ", Colors.YELLOW), end="") + return True # Still in confirmation mode + + def run(self) -> None: + """Run the interactive chat loop.""" + # Print banner and welcome + print(colorize(self.BANNER, Colors.BRIGHT_MAGENTA, Colors.BOLD)) + print(colorize(self.WELCOME_MESSAGE, Colors.DIM)) + print(colorize(f" Model: {self._agent.llm_model}", Colors.DIM)) + print() + print(colorize(" " + "─" * 67, Colors.DIM)) + print() + + while True: + try: + # Handle pending tool confirmation + if self._pending_invocation: + user_input = input() + if self._handle_tool_confirmation(user_input): + if not self._pending_invocation: + # Confirmation was processed, continue to next iteration + continue + else: + # Still waiting for valid y/n + continue + else: + self._print_user_prompt() + user_input = input() + + # Check for commands + normalized = user_input.strip().lower() + + if normalized in ("quit", "exit", "q"): + print() + print(colorize(" Goodbye! 
👋", Colors.CYAN, Colors.BOLD)) + print() + break + + if normalized == "reset": + self._agent.reset() + self._pending_invocation = None + print() + print(colorize(" ↺ Conversation reset", Colors.YELLOW)) + print() + continue + + if not user_input.strip(): + continue + + # Process the message + self._agent.process_user_message(user_input) + + except KeyboardInterrupt: + print() + print(colorize("\n Interrupted. Goodbye! 👋", Colors.CYAN)) + print() + break + + except EOFError: + print() + print(colorize("\n Goodbye! 👋", Colors.CYAN)) + print() + break + + +def parse_model(model_str: str) -> OpenAIModel: + """Parse a model string into an OpenAIModel enum value.""" + for model in OpenAIModel: + if model.value == model_str or model.name == model_str: + return model + raise ValueError(f"Unknown model: {model_str}") + + +def main() -> None: + """Entry point for the guide agent terminal.""" + parser = argparse.ArgumentParser(description="Interactive Guide Agent terminal") + parser.add_argument( + "--model", + type=str, + default=OpenAIModel.GPT_5_MINI.value, + help=f"LLM model to use (default: {OpenAIModel.GPT_5_MINI.value})", + ) + args = parser.parse_args() + + try: + llm_model = parse_model(args.model) + chat = TerminalGuideChat(llm_model=llm_model) + chat.run() + except ValueError as e: + print(colorize(f"\n Error: {e}", Colors.RED, Colors.BOLD), file=sys.stderr) + sys.exit(1) + except Exception as e: + print(colorize(f"\n Fatal error: {e}", Colors.RED, Colors.BOLD), file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/conftest.py b/tests/conftest.py index ff4d7b1e..d4c50788 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Any +from unittest.mock import MagicMock, patch import pytest @@ -13,6 +14,24 @@ from web_hacker.data_models.routine.operation import RoutineOperationUnion +@pytest.fixture(autouse=True) +def mock_openai_clients() -> dict[str, MagicMock]: + """ + Mock OpenAI clients to avoid needing real API keys in tests. + + This fixture automatically patches the OpenAI and AsyncOpenAI classes + wherever they're imported in the openai_client module, so tests can + instantiate OpenAIClient without a valid API key. + """ + with ( + patch("web_hacker.llms.openai_client.OpenAI") as mock_openai, + patch("web_hacker.llms.openai_client.AsyncOpenAI") as mock_async_openai, + ): + mock_openai.return_value = MagicMock() + mock_async_openai.return_value = MagicMock() + yield {"openai": mock_openai, "async_openai": mock_async_openai} + + @pytest.fixture(scope="session") def tests_root() -> Path: """ diff --git a/tests/unit/test_openai_client.py b/tests/unit/test_openai_client.py new file mode 100644 index 00000000..cf803496 --- /dev/null +++ b/tests/unit/test_openai_client.py @@ -0,0 +1,203 @@ +""" +tests/unit/test_openai_client.py + +Unit tests for OpenAI client validation logic. 
+""" + +import pytest + +from web_hacker.data_models.llms.vendors import OpenAIAPIType, OpenAIModel +from web_hacker.llms.openai_client import OpenAIClient + + +class TestValidateAndResolveAPIType: + """Tests for _validate_and_resolve_api_type method.""" + + @pytest.fixture + def client(self) -> OpenAIClient: + """Create an OpenAIClient instance for testing.""" + return OpenAIClient(model=OpenAIModel.GPT_5_MINI) + + # Happy path tests - valid combinations + + def test_default_resolves_to_chat_completions(self, client: OpenAIClient) -> None: + """Default (no special params) should resolve to Chat Completions API.""" + result = client._validate_and_resolve_api_type( + api_type=None, + extended_reasoning=False, + previous_response_id=None, + ) + assert result == OpenAIAPIType.CHAT_COMPLETIONS + + def test_extended_reasoning_auto_resolves_to_responses(self, client: OpenAIClient) -> None: + """extended_reasoning=True should auto-resolve to Responses API.""" + result = client._validate_and_resolve_api_type( + api_type=None, + extended_reasoning=True, + previous_response_id=None, + ) + assert result == OpenAIAPIType.RESPONSES + + def test_previous_response_id_auto_resolves_to_responses(self, client: OpenAIClient) -> None: + """previous_response_id should auto-resolve to Responses API.""" + result = client._validate_and_resolve_api_type( + api_type=None, + extended_reasoning=False, + previous_response_id="resp_123", + ) + assert result == OpenAIAPIType.RESPONSES + + def test_explicit_chat_completions_api_type(self, client: OpenAIClient) -> None: + """Explicit Chat Completions API type should be honored.""" + result = client._validate_and_resolve_api_type( + api_type=OpenAIAPIType.CHAT_COMPLETIONS, + extended_reasoning=False, + previous_response_id=None, + ) + assert result == OpenAIAPIType.CHAT_COMPLETIONS + + def test_explicit_responses_api_type(self, client: OpenAIClient) -> None: + """Explicit Responses API type should be honored.""" + result = client._validate_and_resolve_api_type( + api_type=OpenAIAPIType.RESPONSES, + extended_reasoning=False, + previous_response_id=None, + ) + assert result == OpenAIAPIType.RESPONSES + + def test_extended_reasoning_with_responses_api_valid(self, client: OpenAIClient) -> None: + """extended_reasoning=True with explicit Responses API should work.""" + result = client._validate_and_resolve_api_type( + api_type=OpenAIAPIType.RESPONSES, + extended_reasoning=True, + previous_response_id=None, + ) + assert result == OpenAIAPIType.RESPONSES + + def test_previous_response_id_with_responses_api_valid(self, client: OpenAIClient) -> None: + """previous_response_id with explicit Responses API should work.""" + result = client._validate_and_resolve_api_type( + api_type=OpenAIAPIType.RESPONSES, + extended_reasoning=False, + previous_response_id="resp_123", + ) + assert result == OpenAIAPIType.RESPONSES + + # Error cases - invalid combinations + + def test_extended_reasoning_with_chat_completions_raises_error(self, client: OpenAIClient) -> None: + """extended_reasoning=True with Chat Completions API should raise ValueError.""" + with pytest.raises(ValueError, match="extended_reasoning=True requires Responses API"): + client._validate_and_resolve_api_type( + api_type=OpenAIAPIType.CHAT_COMPLETIONS, + extended_reasoning=True, + previous_response_id=None, + ) + + def test_previous_response_id_with_chat_completions_raises_error(self, client: OpenAIClient) -> None: + """previous_response_id with Chat Completions API should raise ValueError.""" + with pytest.raises(ValueError, 
match="previous_response_id requires Responses API"): + client._validate_and_resolve_api_type( + api_type=OpenAIAPIType.CHAT_COMPLETIONS, + extended_reasoning=False, + previous_response_id="resp_123", + ) + + def test_both_extended_reasoning_and_previous_response_id_with_chat_completions_raises_error( + self, + client: OpenAIClient, + ) -> None: + """Both extended_reasoning and previous_response_id with Chat Completions should raise ValueError.""" + with pytest.raises(ValueError): + client._validate_and_resolve_api_type( + api_type=OpenAIAPIType.CHAT_COMPLETIONS, + extended_reasoning=True, + previous_response_id="resp_123", + ) + + +class TestToolRegistration: + """Tests for tool registration.""" + + @pytest.fixture + def client(self) -> OpenAIClient: + """Create an OpenAIClient instance for testing.""" + return OpenAIClient(model=OpenAIModel.GPT_5_MINI) + + def test_register_tool(self, client: OpenAIClient) -> None: + """Test that tools are registered correctly.""" + client.register_tool( + name="test_tool", + description="A test tool", + parameters={"type": "object", "properties": {}}, + ) + + assert len(client.tools) == 1 + assert client.tools[0]["type"] == "function" + assert client.tools[0]["function"]["name"] == "test_tool" + assert client.tools[0]["function"]["description"] == "A test tool" + + def test_clear_tools(self, client: OpenAIClient) -> None: + """Test that tools can be cleared.""" + client.register_tool( + name="test_tool", + description="A test tool", + parameters={"type": "object", "properties": {}}, + ) + assert len(client.tools) == 1 + + client.clear_tools() + assert len(client.tools) == 0 + + +class TestCallSyncValidation: + """Tests for call_sync parameter validation.""" + + @pytest.fixture + def client(self) -> OpenAIClient: + """Create an OpenAIClient instance for testing.""" + return OpenAIClient(model=OpenAIModel.GPT_5_MINI) + + def test_messages_required_for_chat_completions(self, client: OpenAIClient) -> None: + """Test that messages is required for Chat Completions API.""" + with pytest.raises(ValueError, match="messages is required for Chat Completions API"): + client.call_sync( + messages=None, + input="Hello", + api_type=OpenAIAPIType.CHAT_COMPLETIONS, + ) + + def test_extended_reasoning_with_chat_completions_raises_error(self, client: OpenAIClient) -> None: + """Test that extended_reasoning with Chat Completions raises ValueError.""" + with pytest.raises(ValueError, match="extended_reasoning=True requires Responses API"): + client.call_sync( + messages=[{"role": "user", "content": "Hello"}], + extended_reasoning=True, + api_type=OpenAIAPIType.CHAT_COMPLETIONS, + ) + + +class TestLLMChatResponseFields: + """Tests for LLMChatResponse new fields.""" + + def test_response_id_field_exists(self) -> None: + """Test that response_id field exists on LLMChatResponse.""" + from web_hacker.data_models.llms.interaction import LLMChatResponse + + response = LLMChatResponse(content="test", response_id="resp_123") + assert response.response_id == "resp_123" + + def test_reasoning_content_field_exists(self) -> None: + """Test that reasoning_content field exists on LLMChatResponse.""" + from web_hacker.data_models.llms.interaction import LLMChatResponse + + response = LLMChatResponse(content="test", reasoning_content="I thought about this...") + assert response.reasoning_content == "I thought about this..." 
+ + def test_default_values_are_none(self) -> None: + """Test that new fields default to None.""" + from web_hacker.data_models.llms.interaction import LLMChatResponse + + response = LLMChatResponse(content="test") + assert response.response_id is None + assert response.reasoning_content is None diff --git a/tests/unit/test_tool_utils.py b/tests/unit/test_tool_utils.py new file mode 100644 index 00000000..b5585e04 --- /dev/null +++ b/tests/unit/test_tool_utils.py @@ -0,0 +1,106 @@ +""" +tests/unit/test_tool_utils.py + +Unit tests for LLM tool utilities. +""" + +from web_hacker.llms.tools.tool_utils import ( + extract_description_from_docstring, + generate_parameters_schema, +) + + +class TestExtractDescriptionFromDocstring: + """Tests for docstring description extraction.""" + + def test_single_line_docstring(self) -> None: + docstring = "This is a simple description." + result = extract_description_from_docstring(docstring) + assert result == "This is a simple description." + + def test_multiline_first_paragraph(self) -> None: + docstring = """This is a description + that spans multiple lines.""" + result = extract_description_from_docstring(docstring) + assert result == "This is a description that spans multiple lines." + + def test_extracts_only_first_paragraph(self) -> None: + docstring = """First paragraph here. + + Args: + foo: Some argument. + """ + result = extract_description_from_docstring(docstring) + assert result == "First paragraph here." + + def test_none_docstring(self) -> None: + result = extract_description_from_docstring(None) + assert result == "" + + def test_empty_docstring(self) -> None: + result = extract_description_from_docstring("") + assert result == "" + + def test_strips_leading_whitespace(self) -> None: + docstring = """ + Description with leading whitespace. + """ + result = extract_description_from_docstring(docstring) + assert result == "Description with leading whitespace." 
+ + +class TestGenerateParametersSchema: + """Tests for function parameter schema generation.""" + + def test_simple_string_params(self) -> None: + def example_func(name: str, value: str) -> None: + pass + + schema = generate_parameters_schema(example_func) + assert schema["type"] == "object" + assert schema["required"] == ["name", "value"] + assert schema["properties"]["name"]["type"] == "string" + assert schema["properties"]["value"]["type"] == "string" + + def test_optional_params_not_required(self) -> None: + def example_func(required_param: str, optional_param: str | None = None) -> None: + pass + + schema = generate_parameters_schema(example_func) + assert schema["required"] == ["required_param"] + assert "optional_param" in schema["properties"] + assert "optional_param" not in schema["required"] + + def test_list_type(self) -> None: + def example_func(items: list[str]) -> None: + pass + + schema = generate_parameters_schema(example_func) + assert schema["properties"]["items"]["type"] == "array" + assert schema["properties"]["items"]["items"]["type"] == "string" + + def test_dict_type(self) -> None: + def example_func(data: dict[str, int]) -> None: + pass + + schema = generate_parameters_schema(example_func) + props = schema["properties"]["data"] + assert props["type"] == "object" + assert props["additionalProperties"]["type"] == "integer" + + def test_skips_self_parameter(self) -> None: + class Example: + def method(self, name: str) -> None: + pass + + schema = generate_parameters_schema(Example.method) + assert "self" not in schema["properties"] + assert schema["required"] == ["name"] + + def test_nullable_type_uses_anyof(self) -> None: + def example_func(value: str | None) -> None: + pass + + schema = generate_parameters_schema(example_func) + # pydantic represents str | None as anyOf + assert "anyOf" in schema["properties"]["value"] diff --git a/uv.lock b/uv.lock index ddc66198..b7e6a9ac 100644 --- a/uv.lock +++ b/uv.lock @@ -905,7 +905,7 @@ wheels = [ [[package]] name = "web-hacker" -version = "1.2.2" +version = "1.2.3" source = { editable = "." } dependencies = [ { name = "beautifulsoup4" }, @@ -931,7 +931,7 @@ dev = [ requires-dist = [ { name = "beautifulsoup4", specifier = ">=4.14.2" }, { name = "ipykernel", marker = "extra == 'dev'", specifier = ">=6.29.5" }, - { name = "jmespath", specifier = ">=1.0.1" }, + { name = "jmespath", specifier = ">=1.0.1,<2" }, { name = "openai", specifier = ">=2.6.1" }, { name = "pydantic", specifier = ">=2.11.4" }, { name = "pylint", specifier = ">=3.0.0" }, diff --git a/web_hacker/agents/__init__.py b/web_hacker/agents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/web_hacker/agents/guide_agent.py b/web_hacker/agents/guide_agent.py new file mode 100644 index 00000000..ae6455c6 --- /dev/null +++ b/web_hacker/agents/guide_agent.py @@ -0,0 +1,559 @@ +""" +web_hacker/agents/guide_agent.py + +Guide agent that guides the user through the process of creating or editing a routine. 
+""" + +from datetime import datetime +from typing import Any, Callable +from uuid import uuid4 + +from web_hacker.data_models.llms.interaction import ( + Chat, + ChatMessageType, + ChatRole, + ChatThread, + EmittedChatMessage, + LLMChatResponse, + PendingToolInvocation, + ToolInvocationStatus, +) +from web_hacker.data_models.llms.vendors import OpenAIModel +from web_hacker.llms.llm_client import LLMClient +from web_hacker.llms.tools.guide_agent_tools import start_routine_discovery_job_creation +from web_hacker.utils.exceptions import UnknownToolError +from web_hacker.utils.logger import get_logger + + +logger = get_logger(name=__name__) + + +class GuideAgent: + """ + Guide agent that guides the user through the process of creating or editing a routine. + + The agent maintains a ChatThread with Chat messages and uses LLM tool-calling to determine + when to initiate routine discovery. Tool invocations require user confirmation + via callback before execution. + + Usage: + def handle_message(message: EmittedChatMessage) -> None: + print(f"[{message.type}] {message.content}") + + agent = GuideAgent(emit_message_callable=handle_message) + agent.process_user_message("I want to search for flights") + """ + + # Class constants ______________________________________________________________________________________________________ + + SYSTEM_PROMPT: str = """You are a helpful assistant that guides users through creating \ +web automation routines using the Web Hacker tool. + +## What is Web Hacker? + +Web Hacker is a tool that creates reusable web automation routines by learning from \ +user demonstrations. Users record themselves performing a task on a website, and \ +Web Hacker generates a parameterized routine that can be executed programmatically. + +## Your Role + +Your job is to help users define their automation needs by gathering: + +1. **TASK**: What task do they want to automate? + - Examples: "Search for train tickets", "Download a research paper", "Look up company info" + +2. **OUTPUT**: What data/output should the routine return? + - Examples: "List of available trains with prices", "PDF file of the paper", "Company registration details" + +3. **PARAMETERS**: What input parameters will the routine need? For each: + - Name (e.g., "origin", "destination", "departure_date") + - Description (e.g., "The departure station name") + +4. **CONSTRAINTS**: Any filters or constraints? + - Examples: "Only direct trains", "Papers from 2024", "Only active companies" + +5. **WEBSITE**: What website should be used? + - Examples: "amtrak.com", "arxiv.org", "sec.gov" + +## Guidelines + +- Be conversational and helpful +- Ask clarifying questions if needed +- Don't overwhelm the user with too many questions at once +- When you have enough information, use the start_routine_discovery_job_creation tool +- If the user asks what this tool does, explain it clearly""" + + # Magic methods ________________________________________________________________________________________________________ + + def __init__( + self, + emit_message_callable: Callable[[EmittedChatMessage], None], + persist_chat_callable: Callable[[Chat], Chat] | None = None, + persist_chat_thread_callable: Callable[[ChatThread], ChatThread] | None = None, + stream_chunk_callable: Callable[[str], None] | None = None, + llm_model: OpenAIModel = OpenAIModel.GPT_5_MINI, + chat_thread: ChatThread | None = None, + existing_chats: list[Chat] | None = None, + ) -> None: + """ + Initialize the guide agent. 
+ + Args: + emit_message_callable: Callback function to emit messages to the host. + persist_chat_callable: Optional callback to persist Chat objects (for DynamoDB). + Returns the Chat with the final ID assigned by the persistence layer. + persist_chat_thread_callable: Optional callback to persist ChatThread (for DynamoDB). + Returns the ChatThread with the final ID assigned by the persistence layer. + stream_chunk_callable: Optional callback for streaming text chunks as they arrive. + llm_model: The LLM model to use for conversation. + chat_thread: Existing ChatThread to continue, or None for new conversation. + existing_chats: Existing Chat messages if loading from persistence. + """ + self._emit_message_callable = emit_message_callable + self._persist_chat_callable = persist_chat_callable + self._persist_chat_thread_callable = persist_chat_thread_callable + self._stream_chunk_callable = stream_chunk_callable + + self.llm_model = llm_model + self.llm_client = LLMClient(llm_model) + + # Register tools + self._register_tools() + + # Initialize or load conversation state + self._thread = chat_thread or ChatThread() + self._chats: dict[str, Chat] = {} + if existing_chats: + for chat in existing_chats: + self._chats[chat.id] = chat + + # Persist initial thread if callback provided + if self._persist_chat_thread_callable and chat_thread is None: + self._thread = self._persist_chat_thread_callable(self._thread) + + logger.info( + "Instantiated GuideAgent with model: %s, thread_id: %s", + llm_model, + self._thread.id, + ) + + # Properties ___________________________________________________________________________________________________________ + + @property + def thread_id(self) -> str: + """Return the current thread ID.""" + return self._thread.id + + @property + def has_pending_tool_invocation(self) -> bool: + """Check if there's a pending tool invocation awaiting confirmation.""" + return self._thread.pending_tool_invocation is not None + + # Private methods ______________________________________________________________________________________________________ + + def _register_tools(self) -> None: + """Register all tools with the LLM client.""" + self.llm_client.register_tool_from_function( + func=start_routine_discovery_job_creation, + ) + + def _emit_message(self, message: EmittedChatMessage) -> None: + """Emit a message via the callback.""" + self._emit_message_callable(message) + + def _add_chat(self, role: ChatRole, content: str) -> Chat: + """ + Create and store a new Chat, update thread, persist if callbacks set. + + Args: + role: The role of the message sender. + content: The content of the message. + + Returns: + The created Chat object (with final ID from persistence layer if callback provided). + """ + chat = Chat( + thread_id=self._thread.id, + role=role, + content=content, + ) + + # Persist chat first if callback provided (may assign new ID) + if self._persist_chat_callable: + chat = self._persist_chat_callable(chat) + + # Store with final ID + self._chats[chat.id] = chat + self._thread.message_ids.append(chat.id) + self._thread.updated_at = int(datetime.now().timestamp()) + + # Persist thread if callback provided + if self._persist_chat_thread_callable: + self._thread = self._persist_chat_thread_callable(self._thread) + + return chat + + def _build_messages_for_llm(self) -> list[dict[str, str]]: + """ + Build messages list for LLM from Chat objects. + + Returns: + List of message dicts with 'role' and 'content' keys. 
+ """ + messages: list[dict[str, str]] = [] + for chat_id in self._thread.message_ids: + chat = self._chats.get(chat_id) + if chat: + messages.append({ + "role": chat.role.value, + "content": chat.content, + }) + return messages + + def _create_tool_invocation_request( + self, + tool_name: str, + tool_arguments: dict[str, Any], + ) -> PendingToolInvocation: + """ + Create a tool invocation request for user confirmation. + + Args: + tool_name: Name of the tool to invoke + tool_arguments: Arguments for the tool + + Returns: + PendingToolInvocation stored in state and ready to emit + """ + invocation_id = str(uuid4()) + pending = PendingToolInvocation( + invocation_id=invocation_id, + tool_name=tool_name, + tool_arguments=tool_arguments, + ) + + # Store in thread + self._thread.pending_tool_invocation = pending + self._thread.updated_at = int(datetime.now().timestamp()) + + if self._persist_chat_thread_callable: + self._thread = self._persist_chat_thread_callable(self._thread) + + logger.info( + "Created tool invocation request: %s (tool: %s)", + invocation_id, + tool_name, + ) + + return pending + + def _execute_tool( + self, + tool_name: str, + tool_arguments: dict[str, Any], + ) -> dict[str, Any]: + """ + Execute a confirmed tool invocation. + + Args: + tool_name: Name of the tool to execute + tool_arguments: Arguments for the tool + + Returns: + Tool execution result with thread_id and params + + Raises: + UnknownToolError: If tool_name is unknown + """ + if tool_name == start_routine_discovery_job_creation.__name__: + logger.info( + "Executing tool %s with args: %s", + tool_name, + tool_arguments, + ) + result = start_routine_discovery_job_creation(**tool_arguments) + return { + "thread_id": self._thread.id, + **result, + } + + logger.error("Unknown tool \"%s\" with arguments: %s", tool_name, tool_arguments) + raise UnknownToolError(f"Unknown tool \"{tool_name}\" with arguments: {tool_arguments}") + + # Public methods _______________________________________________________________________________________________________ + + def process_user_message(self, content: str) -> None: + """ + Process a user message and emit responses via callback. + + This method handles the conversation loop: + 1. Adds user message to history + 2. Calls LLM to generate response + 3. 
Emits chat response or tool invocation request + + Args: + content: The user's message content + """ + # Block new messages if there's a pending tool invocation + if self._thread.pending_tool_invocation: + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.ERROR, + error="Please confirm or deny the pending tool invocation before sending new messages", + ) + ) + return + + # Add user message to history + self._add_chat(ChatRole.USER, content) + + # Build messages and call LLM + messages = self._build_messages_for_llm() + + try: + # Use streaming if chunk callback is set + if self._stream_chunk_callable: + response = self._process_streaming_response(messages) + else: + response = self.llm_client.call_sync( + messages=messages, + system_prompt=self.SYSTEM_PROMPT, + ) + + # Handle text response + if response.content: + chat = self._add_chat(ChatRole.ASSISTANT, response.content) + # Always emit CHAT_RESPONSE (handler checks if streaming occurred) + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.CHAT_RESPONSE, + content=response.content, + chat_id=chat.id, + chat_thread_id=self._thread.id, + ) + ) + + # Handle tool call if present + if response.tool_call: + pending = self._create_tool_invocation_request( + response.tool_call.tool_name, + response.tool_call.tool_arguments, + ) + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.TOOL_INVOCATION_REQUEST, + tool_invocation=pending, + ) + ) + + except Exception as e: + logger.exception("Error processing user message: %s", e) + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.ERROR, + error=str(e), + ) + ) + + def _process_streaming_response(self, messages: list[dict[str, str]]) -> LLMChatResponse: + """ + Process LLM response with streaming, calling chunk callback for each chunk. + + Args: + messages: The messages to send to the LLM. + + Returns: + The final LLMChatResponse with complete content. + """ + response: LLMChatResponse | None = None + + for item in self.llm_client.call_stream_sync( + messages=messages, + system_prompt=self.SYSTEM_PROMPT, + ): + if isinstance(item, str): + # Text chunk - call the callback + if self._stream_chunk_callable: + self._stream_chunk_callable(item) + elif isinstance(item, LLMChatResponse): + # Final response + response = item + + if response is None: + raise ValueError("No final response received from streaming LLM") + + return response + + def confirm_tool_invocation(self, invocation_id: str) -> None: + """ + Confirm a pending tool invocation and execute it. 
+ + Args: + invocation_id: ID of the tool invocation to confirm + + Emits: + - TOOL_INVOCATION_RESULT with status "executed" and result on success + - ERROR if no pending invocation or ID mismatch + """ + pending = self._thread.pending_tool_invocation + + if not pending: + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.ERROR, + error="No pending tool invocation to confirm", + ) + ) + return + + if pending.invocation_id != invocation_id: + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.ERROR, + error=f"Invocation ID mismatch: expected {pending.invocation_id}", + ) + ) + return + + # Update status + pending.status = ToolInvocationStatus.CONFIRMED + + try: + result = self._execute_tool(pending.tool_name, pending.tool_arguments) + pending.status = ToolInvocationStatus.EXECUTED + self._thread.pending_tool_invocation = None + self._thread.updated_at = int(datetime.now().timestamp()) + + if self._persist_chat_thread_callable: + self._thread = self._persist_chat_thread_callable(self._thread) + + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.TOOL_INVOCATION_RESULT, + tool_invocation=pending, + tool_result=result, + ) + ) + + logger.info( + "Tool invocation %s executed successfully", + invocation_id, + ) + + except Exception as e: + pending.status = ToolInvocationStatus.FAILED + + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.TOOL_INVOCATION_RESULT, + tool_invocation=pending, + error=str(e), + ) + ) + + logger.exception( + "Tool invocation %s failed: %s", + invocation_id, + e, + ) + + def deny_tool_invocation( + self, + invocation_id: str, + reason: str | None = None, + ) -> None: + """ + Deny a pending tool invocation. + + Args: + invocation_id: ID of the tool invocation to deny + reason: Optional reason for denial + + Emits: + - TOOL_INVOCATION_RESULT with status "denied" + - ERROR if no pending invocation or ID mismatch + """ + pending = self._thread.pending_tool_invocation + + if not pending: + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.ERROR, + error="No pending tool invocation to deny", + ) + ) + return + + if pending.invocation_id != invocation_id: + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.ERROR, + error=f"Invocation ID mismatch: expected {pending.invocation_id}", + ) + ) + return + + # Update status and clear pending + pending.status = ToolInvocationStatus.DENIED + self._thread.pending_tool_invocation = None + self._thread.updated_at = int(datetime.now().timestamp()) + + if self._persist_chat_thread_callable: + self._thread = self._persist_chat_thread_callable(self._thread) + + # Add denial to conversation history + denial_message = "Tool invocation denied" + if reason: + denial_message += f": {reason}" + self._add_chat(ChatRole.SYSTEM, denial_message) + + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.TOOL_INVOCATION_RESULT, + tool_invocation=pending, + content=denial_message, + ) + ) + + logger.info( + "Tool invocation %s denied: %s", + invocation_id, + reason or "no reason provided", + ) + + def get_thread(self) -> ChatThread: + """ + Get the current conversation thread. + + Returns: + Current ChatThread + """ + return self._thread + + def get_chats(self) -> list[Chat]: + """ + Get all Chat messages in order. + + Returns: + List of Chat objects in conversation order. + """ + return [self._chats[chat_id] for chat_id in self._thread.message_ids if chat_id in self._chats] + + def reset(self) -> None: + """ + Reset the conversation to a fresh state. 
+ + Generates a new thread and clears all messages. + """ + old_thread_id = self._thread.id + self._thread = ChatThread() + self._chats = {} + + if self._persist_chat_thread_callable: + self._thread = self._persist_chat_thread_callable(self._thread) + + logger.info( + "Reset conversation from %s to %s", + old_thread_id, + self._thread.id, + ) diff --git a/web_hacker/cdp/connection.py b/web_hacker/cdp/connection.py index 30712c8d..645ad5b3 100644 --- a/web_hacker/cdp/connection.py +++ b/web_hacker/cdp/connection.py @@ -10,7 +10,6 @@ """ import json -import random import time from json import JSONDecodeError from typing import Callable diff --git a/web_hacker/data_models/llms/__init__.py b/web_hacker/data_models/llms/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/web_hacker/data_models/llms/interaction.py b/web_hacker/data_models/llms/interaction.py new file mode 100644 index 00000000..610694fe --- /dev/null +++ b/web_hacker/data_models/llms/interaction.py @@ -0,0 +1,192 @@ +""" +web_hacker/data_models/llms/interaction.py + +Data models for LLM interactions and agent communication. +""" + +from datetime import datetime, timezone +from enum import StrEnum +from typing import Any +from uuid import uuid4 + +from pydantic import BaseModel, Field + + +class ChatRole(StrEnum): + """ + Role in a chat message. + """ + USER = "user" + ASSISTANT = "assistant" # AI + SYSTEM = "system" + TOOL = "tool" + + +class ToolInvocationStatus(StrEnum): + """ + Status of a tool invocation. + """ + PENDING_CONFIRMATION = "pending_confirmation" + CONFIRMED = "confirmed" + DENIED = "denied" + EXECUTED = "executed" + FAILED = "failed" + + +class PendingToolInvocation(BaseModel): + """ + A tool invocation awaiting user confirmation. + """ + invocation_id: str = Field( + ..., + description="Unique ID for this invocation (UUIDv4)", + ) + tool_name: str = Field( + ..., + description="Name of the tool to invoke", + ) + tool_arguments: dict[str, Any] = Field( + ..., + description="Arguments to pass to the tool", + ) + status: ToolInvocationStatus = Field( + default=ToolInvocationStatus.PENDING_CONFIRMATION, + description="Current status of the invocation", + ) + created_at: datetime = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + description="When the invocation was created", + ) + + +class ChatMessageType(StrEnum): + """ + Types of messages the chat can emit via callback. + """ + + CHAT_RESPONSE = "chat_response" + TOOL_INVOCATION_REQUEST = "tool_invocation_request" + TOOL_INVOCATION_RESULT = "tool_invocation_result" + ERROR = "error" + + +class EmittedChatMessage(BaseModel): + """ + Message emitted by the guide agent via callback. + + This is the internal message format used by GuideAgent to communicate + with its host (e.g., CLI, WebSocket handler in servers repo). 
+ """ + + type: ChatMessageType = Field( + ..., + description="The type of message being emitted", + ) + timestamp: datetime = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + description="When the message was created", + ) + content: str | None = Field( + default=None, + description="Text content for chat responses or error messages", + ) + chat_id: str | None = Field( + default=None, + description="ID of the Chat message (for CHAT_RESPONSE messages)", + ) + chat_thread_id: str | None = Field( + default=None, + description="ID of the ChatThread this message belongs to", + ) + tool_invocation: PendingToolInvocation | None = Field( + default=None, + description="Tool invocation details for request/result messages", + ) + tool_result: dict[str, Any] | None = Field( + default=None, + description="Result data from tool execution", + ) + error: str | None = Field( + default=None, + description="Error message if type is ERROR", + ) + + +class LLMToolCall(BaseModel): + """ + A tool call requested by the LLM. + """ + tool_name: str = Field( + ..., + description="Name of the tool to invoke", + ) + tool_arguments: dict[str, Any] = Field( + ..., + description="Arguments to pass to the tool", + ) + + +class LLMChatResponse(BaseModel): + """ + Response from an LLM chat completion with tool support. + """ + content: str | None = Field( + default=None, + description="Text content of the response", + ) + tool_call: LLMToolCall | None = Field( + default=None, + description="Tool call requested by the LLM, if any", + ) + response_id: str | None = Field( + default=None, + description="Response ID for chaining (Responses API)", + ) + reasoning_content: str | None = Field( + default=None, + description="Extended reasoning content", + ) + + +class Chat(BaseModel): + """ + A single message in a conversation. + """ + id: str = Field( + default_factory=lambda: str(uuid4()), + description="Unique message ID (UUIDv4)", + ) + thread_id: str = Field( + ..., + description="ID of the parent thread this message belongs to", + ) + role: ChatRole = Field( + ..., + description="The role of the message sender (user, assistant, system, tool)", + ) + content: str = Field( + ..., + description="The content of the message", + ) + + +class ChatThread(BaseModel): + """ + Container for a conversation thread. + """ + id: str = Field( + default_factory=lambda: str(uuid4()), + description="Unique thread ID (UUIDv4)", + ) + message_ids: list[str] = Field( + default_factory=list, + description="Ordered list of message IDs in this thread", + ) + pending_tool_invocation: PendingToolInvocation | None = Field( + default=None, + description="Tool invocation awaiting user confirmation, if any", + ) + updated_at: int = Field( + default=0, + description="Unix timestamp (seconds) when thread was last updated", + ) diff --git a/web_hacker/data_models/llms/vendors.py b/web_hacker/data_models/llms/vendors.py new file mode 100644 index 00000000..fb00c197 --- /dev/null +++ b/web_hacker/data_models/llms/vendors.py @@ -0,0 +1,61 @@ +""" +web_hacker/data_models/llms/vendors.py + +This module contains the LLM vendor models. +""" + +from enum import StrEnum + + +class LLMVendor(StrEnum): + """ + Represents the vendor of an LLM. + """ + OPENAI = "openai" + #TODO:ANTHROPIC = "anthropic" + + +class OpenAIAPIType(StrEnum): + """ + OpenAI API type. + """ + CHAT_COMPLETIONS = "chat_completions" + RESPONSES = "responses" + + +class OpenAIModel(StrEnum): + """ + OpenAI models. 
+ """ + GPT_5 = "gpt-5" + GPT_5_1 = "gpt-5.1" + GPT_5_2 = "gpt-5.2" + GPT_5_MINI = "gpt-5-mini" + GPT_5_NANO = "gpt-5-nano" + + +# class AnthropicModel(StrEnum): +# """Anthropic models.""" +# CLAUDE_OPUS_4_5 = "claude-opus-4-5-20251101" +# CLAUDE_SONNET_4_5 = "claude-sonnet-4-5-20250929" +# CLAUDE_HAIKU_4_5 = "claude-haiku-4-5-20251001" + + +# LLMModel type; only OpenAI models for now +type LLMModel = OpenAIModel + + +# Build model to vendor lookup from OpenAI models only +_model_to_vendor: dict[str, LLMVendor] = {} +_all_models: dict[str, str] = {} + +for model in OpenAIModel: + _model_to_vendor[model.value] = LLMVendor.OPENAI + _all_models[model.name] = model.value + + +def get_model_vendor(model: LLMModel) -> LLMVendor: + """ + Returns the vendor of the LLM model. + """ + return _model_to_vendor[model.value] diff --git a/web_hacker/llms/__init__.py b/web_hacker/llms/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/web_hacker/llms/abstract_llm_vendor_client.py b/web_hacker/llms/abstract_llm_vendor_client.py new file mode 100644 index 00000000..b15d867d --- /dev/null +++ b/web_hacker/llms/abstract_llm_vendor_client.py @@ -0,0 +1,171 @@ +""" +web_hacker/llms/abstract_llm_vendor_client.py + +Abstract base class for LLM vendor clients. +""" + +from abc import ABC, abstractmethod +from collections.abc import Generator +from typing import Any, ClassVar, TypeVar + +from pydantic import BaseModel + +from web_hacker.data_models.llms.interaction import LLMChatResponse +from web_hacker.data_models.llms.vendors import OpenAIModel + + +T = TypeVar("T", bound=BaseModel) + + +class AbstractLLMVendorClient(ABC): + """ + Abstract base class defining the interface for LLM vendor clients. + + All vendor-specific clients must implement this interface to ensure + consistent behavior across the LLMClient. + """ + + # Class attributes ____________________________________________________________________________________________________ + + DEFAULT_MAX_TOKENS: ClassVar[int] = 4_096 + DEFAULT_TEMPERATURE: ClassVar[float] = 0.7 + DEFAULT_STRUCTURED_TEMPERATURE: ClassVar[float] = 0.0 + + # Magic methods ________________________________________________________________________________________________________ + + def __init__(self, model: OpenAIModel) -> None: + """ + Initialize the vendor client. + + Args: + model: The LLM model to use. + """ + self.model = model + self._tools: list[dict[str, Any]] = [] + + # Tool management ______________________________________________________________________________________________________ + + @abstractmethod + def register_tool( + self, + name: str, + description: str, + parameters: dict[str, Any], + ) -> None: + """ + Register a tool for function calling. + + Args: + name: The name of the tool/function. + description: Description of what the tool does. + parameters: JSON Schema describing the tool's parameters. 
+ """ + pass + + def clear_tools(self) -> None: + """Clear all registered tools.""" + self._tools = [] + + @property + def tools(self) -> list[dict[str, Any]]: + """Return the list of registered tools.""" + return self._tools + + # Unified API methods __________________________________________________________________________________________________ + + @abstractmethod + def call_sync( + self, + messages: list[dict[str, str]] | None = None, + input: str | None = None, + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + response_model: type[T] | None = None, + extended_reasoning: bool = False, + stateful: bool = False, + previous_response_id: str | None = None, + ) -> LLMChatResponse | T: + """ + Unified sync call to the LLM. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + input: Input string (shorthand for simple prompts). + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. + temperature: Sampling temperature (0.0-1.0). + response_model: Pydantic model class for structured response. + extended_reasoning: Enable extended reasoning (if supported). + stateful: Enable stateful conversation (if supported). + previous_response_id: Previous response ID for chaining (if supported). + + Returns: + LLMChatResponse or parsed Pydantic model if response_model is provided. + """ + pass + + @abstractmethod + async def call_async( + self, + messages: list[dict[str, str]] | None = None, + input: str | None = None, + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + response_model: type[T] | None = None, + extended_reasoning: bool = False, + stateful: bool = False, + previous_response_id: str | None = None, + ) -> LLMChatResponse | T: + """ + Unified async call to the LLM. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + input: Input string (shorthand for simple prompts). + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. + temperature: Sampling temperature (0.0-1.0). + response_model: Pydantic model class for structured response. + extended_reasoning: Enable extended reasoning (if supported). + stateful: Enable stateful conversation (if supported). + previous_response_id: Previous response ID for chaining (if supported). + + Returns: + LLMChatResponse or parsed Pydantic model if response_model is provided. + """ + pass + + @abstractmethod + def call_stream_sync( + self, + messages: list[dict[str, str]] | None = None, + input: str | None = None, + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + extended_reasoning: bool = False, + stateful: bool = False, + previous_response_id: str | None = None, + ) -> Generator[str | LLMChatResponse, None, None]: + """ + Unified streaming call to the LLM. + + Yields text chunks as they arrive, then yields the final LLMChatResponse. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + input: Input string (shorthand for simple prompts). + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. + temperature: Sampling temperature (0.0-1.0). + extended_reasoning: Enable extended reasoning (if supported). + stateful: Enable stateful conversation (if supported). + previous_response_id: Previous response ID for chaining (if supported). + + Yields: + str: Text chunks as they arrive. 
+ LLMChatResponse: Final response with complete content and optional tool call. + """ + pass diff --git a/web_hacker/llms/llm_client.py b/web_hacker/llms/llm_client.py new file mode 100644 index 00000000..372fd1a8 --- /dev/null +++ b/web_hacker/llms/llm_client.py @@ -0,0 +1,221 @@ +""" +web_hacker/llms/llm_client.py + +Unified LLM client supporting OpenAI models. +""" + +from collections.abc import Generator +from typing import Any, Callable, TypeVar + +from pydantic import BaseModel + +from web_hacker.data_models.llms.interaction import LLMChatResponse +from web_hacker.data_models.llms.vendors import OpenAIAPIType, OpenAIModel +from web_hacker.llms.openai_client import OpenAIClient +from web_hacker.llms.tools.tool_utils import extract_description_from_docstring, generate_parameters_schema +from web_hacker.utils.logger import get_logger + +logger = get_logger(name=__name__) + + +T = TypeVar("T", bound=BaseModel) + + +class LLMClient: + """ + Unified LLM client class for interacting with OpenAI APIs. + + This is a facade that delegates to OpenAIClient and provides + a convenient interface for tool registration. + + Supports: + - Sync and async API calls + - Streaming responses + - Structured responses using Pydantic models + - Tool/function registration + - Both Chat Completions and Responses APIs + """ + + # Magic methods ________________________________________________________________________________________________________ + + def __init__(self, llm_model: OpenAIModel) -> None: + self.llm_model = llm_model + self._client = OpenAIClient(model=llm_model) + logger.info("Instantiated LLMClient with model: %s", llm_model) + + # Public methods _______________________________________________________________________________________________________ + + ## Tools + + def register_tool( + self, + name: str, + description: str, + parameters: dict[str, Any], + ) -> None: + """ + Register a tool for function calling. + + Args: + name: The name of the tool/function. + description: Description of what the tool does. + parameters: JSON Schema describing the tool's parameters. + """ + logger.debug("Registering tool %s (description: %s) with parameters: %s", name, description, parameters) + self._client.register_tool(name, description, parameters) + + def register_tool_from_function(self, func: Callable[..., Any]) -> None: + """ + Register a tool from a Python function, extracting metadata automatically. + + Extracts: + - name: from func.__name__ + - description: from the docstring (first paragraph) + - parameters: JSON Schema generated from type hints via pydantic + + Args: + func: The function to register as a tool. Must have type hints. + """ + name = func.__name__ + description = extract_description_from_docstring(func.__doc__) + parameters = generate_parameters_schema(func) + self.register_tool(name, description, parameters) + + def clear_tools(self) -> None: + """Clear all registered tools.""" + self._client.clear_tools() + logger.debug("Cleared all registered tools") + + ## Unified API methods + + def call_sync( + self, + messages: list[dict[str, str]] | None = None, + input: str | None = None, + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + response_model: type[T] | None = None, + extended_reasoning: bool = False, + stateful: bool = False, + previous_response_id: str | None = None, + api_type: OpenAIAPIType | None = None, + ) -> LLMChatResponse | T: + """ + Unified sync call to OpenAI. 
+ + Args: + messages: List of message dicts with 'role' and 'content' keys. + input: Input string (Responses API shorthand). + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. + temperature: Sampling temperature (0.0-1.0). + response_model: Pydantic model class for structured response. + extended_reasoning: Enable extended reasoning (Responses API only). + stateful: Enable stateful conversation (Responses API only). + previous_response_id: Previous response ID for chaining (Responses API only). + api_type: Explicit API type, or None for auto-resolution. + + Returns: + LLMChatResponse or parsed Pydantic model if response_model is provided. + """ + return self._client.call_sync( + messages=messages, + input=input, + system_prompt=system_prompt, + max_tokens=max_tokens, + temperature=temperature, + response_model=response_model, + extended_reasoning=extended_reasoning, + stateful=stateful, + previous_response_id=previous_response_id, + api_type=api_type, + ) + + async def call_async( + self, + messages: list[dict[str, str]] | None = None, + input: str | None = None, + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + response_model: type[T] | None = None, + extended_reasoning: bool = False, + stateful: bool = False, + previous_response_id: str | None = None, + api_type: OpenAIAPIType | None = None, + ) -> LLMChatResponse | T: + """ + Unified async call to OpenAI. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + input: Input string (Responses API shorthand). + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. + temperature: Sampling temperature (0.0-1.0). + response_model: Pydantic model class for structured response. + extended_reasoning: Enable extended reasoning (Responses API only). + stateful: Enable stateful conversation (Responses API only). + previous_response_id: Previous response ID for chaining (Responses API only). + api_type: Explicit API type, or None for auto-resolution. + + Returns: + LLMChatResponse or parsed Pydantic model if response_model is provided. + """ + return await self._client.call_async( + messages=messages, + input=input, + system_prompt=system_prompt, + max_tokens=max_tokens, + temperature=temperature, + response_model=response_model, + extended_reasoning=extended_reasoning, + stateful=stateful, + previous_response_id=previous_response_id, + api_type=api_type, + ) + + def call_stream_sync( + self, + messages: list[dict[str, str]] | None = None, + input: str | None = None, + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + extended_reasoning: bool = False, + stateful: bool = False, + previous_response_id: str | None = None, + api_type: OpenAIAPIType | None = None, + ) -> Generator[str | LLMChatResponse, None, None]: + """ + Unified streaming call to OpenAI. + + Yields text chunks as they arrive, then yields the final LLMChatResponse. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + input: Input string (Responses API shorthand). + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. + temperature: Sampling temperature (0.0-1.0). + extended_reasoning: Enable extended reasoning (Responses API only). + stateful: Enable stateful conversation (Responses API only). + previous_response_id: Previous response ID for chaining (Responses API only). 
+ api_type: Explicit API type, or None for auto-resolution. + + Yields: + str: Text chunks as they arrive. + LLMChatResponse: Final response with complete content and optional tool call. + """ + yield from self._client.call_stream_sync( + messages=messages, + input=input, + system_prompt=system_prompt, + max_tokens=max_tokens, + temperature=temperature, + extended_reasoning=extended_reasoning, + stateful=stateful, + previous_response_id=previous_response_id, + api_type=api_type, + ) diff --git a/web_hacker/llms/openai_client.py b/web_hacker/llms/openai_client.py new file mode 100644 index 00000000..b4d8d62b --- /dev/null +++ b/web_hacker/llms/openai_client.py @@ -0,0 +1,565 @@ +""" +web_hacker/llms/openai_client.py + +OpenAI-specific LLM client implementation with unified API supporting +both Chat Completions and Responses APIs. +""" + +import json +from collections.abc import Generator +from typing import Any, TypeVar + +from openai import AsyncOpenAI, OpenAI +from pydantic import BaseModel + +from web_hacker.config import Config +from web_hacker.data_models.llms.interaction import LLMChatResponse, LLMToolCall +from web_hacker.data_models.llms.vendors import OpenAIAPIType, OpenAIModel +from web_hacker.llms.abstract_llm_vendor_client import AbstractLLMVendorClient +from web_hacker.utils.logger import get_logger + +logger = get_logger(name=__name__) + + +T = TypeVar("T", bound=BaseModel) + + +class OpenAIClient(AbstractLLMVendorClient): + """ + OpenAI-specific LLM client with unified API. + + Supports both Chat Completions API and Responses API with automatic + API type resolution based on parameters. + """ + + # Magic methods ________________________________________________________________________________________________________ + + def __init__(self, model: OpenAIModel) -> None: + super().__init__(model) + self._client = OpenAI(api_key=Config.OPENAI_API_KEY) + self._async_client = AsyncOpenAI(api_key=Config.OPENAI_API_KEY) + logger.debug("Initialized OpenAIClient with model: %s", model) + + # Private methods ______________________________________________________________________________________________________ + + def _resolve_max_tokens(self, max_tokens: int | None) -> int: + """Resolve max_tokens, using default if None.""" + return max_tokens if max_tokens is not None else self.DEFAULT_MAX_TOKENS + + def _resolve_temperature( + self, + temperature: float | None, + structured: bool = False, + ) -> float: + """Resolve temperature, using appropriate default if None.""" + if temperature is not None: + return temperature + return self.DEFAULT_STRUCTURED_TEMPERATURE if structured else self.DEFAULT_TEMPERATURE + + def _prepend_system_prompt( + self, + messages: list[dict[str, str]], + system_prompt: str | None, + ) -> list[dict[str, str]]: + """Prepend system prompt to messages if provided.""" + if system_prompt: + return [{"role": "system", "content": system_prompt}] + messages + return messages + + def _validate_and_resolve_api_type( + self, + api_type: OpenAIAPIType | None, + extended_reasoning: bool, + previous_response_id: str | None, + ) -> OpenAIAPIType: + """ + Validate params and resolve API type. Raises ValueError for invalid combos. + + Args: + api_type: Explicit API type, or None for auto-resolution. + extended_reasoning: Whether extended reasoning is requested. + previous_response_id: Previous response ID for chaining. + + Returns: + The resolved API type. + + Raises: + ValueError: If incompatible parameters are combined. 
+ """ + if extended_reasoning and api_type == OpenAIAPIType.CHAT_COMPLETIONS: + raise ValueError("extended_reasoning=True requires Responses API") + if previous_response_id and api_type == OpenAIAPIType.CHAT_COMPLETIONS: + raise ValueError("previous_response_id requires Responses API") + + # Auto-resolve + if api_type is None: + if extended_reasoning or previous_response_id: + resolved = OpenAIAPIType.RESPONSES + else: + resolved = OpenAIAPIType.CHAT_COMPLETIONS + logger.debug("Auto-resolved API type to: %s", resolved.value) + return resolved + + logger.debug("Using explicit API type: %s", api_type.value) + return api_type + + def _build_chat_completions_kwargs( + self, + messages: list[dict[str, str]], + system_prompt: str | None, + max_tokens: int | None, + response_model: type[T] | None, + stream: bool = False, + ) -> dict[str, Any]: + """Build kwargs for Chat Completions API call.""" + all_messages = self._prepend_system_prompt(messages, system_prompt) + + kwargs: dict[str, Any] = { + "model": self.model.value, + "messages": all_messages, + "max_completion_tokens": self._resolve_max_tokens(max_tokens), + } + + if stream: + kwargs["stream"] = True + + if self._tools and response_model is None: + kwargs["tools"] = self._tools + + return kwargs + + def _build_responses_api_kwargs( + self, + messages: list[dict[str, str]] | None, + input_text: str | None, + system_prompt: str | None, + max_tokens: int | None, + extended_reasoning: bool, + previous_response_id: str | None, + response_model: type[T] | None, + stream: bool = False, + ) -> dict[str, Any]: + """Build kwargs for Responses API call.""" + kwargs: dict[str, Any] = { + "model": self.model.value, + "max_output_tokens": self._resolve_max_tokens(max_tokens), + } + + # Handle input: either input string or messages array + if previous_response_id: + kwargs["previous_response_id"] = previous_response_id + # When chaining, input is the new user message + if input_text: + kwargs["input"] = input_text + elif messages: + # Use the last user message as input for chaining + user_messages = [m for m in messages if m.get("role") == "user"] + if user_messages: + kwargs["input"] = user_messages[-1]["content"] + elif input_text: + kwargs["input"] = input_text + elif messages: + # Convert messages to Responses API format + all_messages = self._prepend_system_prompt(messages, system_prompt) + kwargs["input"] = all_messages + else: + raise ValueError("Either messages or input must be provided") + + # Add system instructions if provided and not using messages + if system_prompt and input_text and not messages: + kwargs["instructions"] = system_prompt + + if stream: + kwargs["stream"] = True + + if extended_reasoning: + kwargs["reasoning"] = {"effort": "medium"} + + if self._tools and response_model is None: + kwargs["tools"] = self._tools + + return kwargs + + def _parse_chat_completions_response( + self, + response: Any, + response_model: type[T] | None, + ) -> LLMChatResponse | T: + """Parse response from Chat Completions API.""" + message = response.choices[0].message + + # Handle structured response + if response_model is not None: + parsed = getattr(message, "parsed", None) + if parsed is None: + raise ValueError("Failed to parse structured response from OpenAI") + return parsed + + # Extract tool call if present + tool_call: LLMToolCall | None = None + if message.tool_calls and len(message.tool_calls) > 0: + tc = message.tool_calls[0] + tool_call = LLMToolCall( + tool_name=tc.function.name, + tool_arguments=json.loads(tc.function.arguments), + ) + 
+ return LLMChatResponse( + content=message.content, + tool_call=tool_call, + ) + + def _parse_responses_api_response( + self, + response: Any, + response_model: type[T] | None, + ) -> LLMChatResponse | T: + """Parse response from Responses API.""" + # Handle structured response + if response_model is not None: + # Responses API returns structured output differently + output = response.output + if output and len(output) > 0: + for item in output: + if hasattr(item, "content") and item.content: + for content_block in item.content: + if hasattr(content_block, "parsed") and content_block.parsed: + return content_block.parsed + raise ValueError("Failed to parse structured response from OpenAI Responses API") + + # Extract content and tool calls + content: str | None = None + tool_call: LLMToolCall | None = None + reasoning_content: str | None = None + + output = response.output + if output: + for item in output: + # Handle reasoning content + if item.type == "reasoning": + if hasattr(item, "summary") and item.summary: + reasoning_parts = [] + for summary_item in item.summary: + if hasattr(summary_item, "text"): + reasoning_parts.append(summary_item.text) + if reasoning_parts: + reasoning_content = "".join(reasoning_parts) + + # Handle message content + if item.type == "message": + if hasattr(item, "content") and item.content: + text_parts = [] + for content_block in item.content: + if content_block.type == "output_text": + text_parts.append(content_block.text) + if text_parts: + content = "".join(text_parts) + + # Handle function calls + if item.type == "function_call": + tool_call = LLMToolCall( + tool_name=item.name, + tool_arguments=json.loads(item.arguments) if isinstance(item.arguments, str) else item.arguments, + ) + + return LLMChatResponse( + content=content, + tool_call=tool_call, + response_id=response.id, + reasoning_content=reasoning_content, + ) + + # Public methods _______________________________________________________________________________________________________ + + ## Tool management + + def register_tool( + self, + name: str, + description: str, + parameters: dict[str, Any], + ) -> None: + """Register a tool in OpenAI's function calling format.""" + logger.debug("Registering OpenAI tool: %s", name) + self._tools.append({ + "type": "function", + "function": { + "name": name, + "description": description, + "parameters": parameters, + } + }) + + ## Unified API methods + + def call_sync( + self, + messages: list[dict[str, str]] | None = None, + input: str | None = None, + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, # noqa: ARG002 - reserved for future use + response_model: type[T] | None = None, + extended_reasoning: bool = False, + stateful: bool = False, # noqa: ARG002 - reserved for future use + previous_response_id: str | None = None, + api_type: OpenAIAPIType | None = None, + ) -> LLMChatResponse | T: + """ + Unified sync call to OpenAI. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + input: Input string (Responses API shorthand). + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. + temperature: Sampling temperature (0.0-1.0). + response_model: Pydantic model class for structured response. + extended_reasoning: Enable extended reasoning (Responses API only). + stateful: Enable stateful conversation (Responses API only). + previous_response_id: Previous response ID for chaining (Responses API only). 
+ api_type: Explicit API type, or None for auto-resolution. + + Returns: + LLMChatResponse or parsed Pydantic model if response_model is provided. + + Raises: + ValueError: If incompatible parameters are combined. + """ + resolved_api_type = self._validate_and_resolve_api_type( + api_type, extended_reasoning, previous_response_id + ) + + if resolved_api_type == OpenAIAPIType.CHAT_COMPLETIONS: + if messages is None: + raise ValueError("messages is required for Chat Completions API") + + if response_model is not None: + # Use beta.chat.completions.parse for structured output + kwargs = self._build_chat_completions_kwargs( + messages, system_prompt, max_tokens, response_model + ) + response = self._client.beta.chat.completions.parse( + **kwargs, + response_format=response_model, + ) + else: + kwargs = self._build_chat_completions_kwargs( + messages, system_prompt, max_tokens, response_model + ) + response = self._client.chat.completions.create(**kwargs) + + return self._parse_chat_completions_response(response, response_model) + + else: # Responses API + kwargs = self._build_responses_api_kwargs( + messages, input, system_prompt, max_tokens, + extended_reasoning, previous_response_id, response_model + ) + + if response_model is not None: + # Add structured output format + kwargs["text"] = {"format": {"type": "json_schema", "schema": response_model.model_json_schema()}} + + response = self._client.responses.create(**kwargs) + return self._parse_responses_api_response(response, response_model) + + async def call_async( + self, + messages: list[dict[str, str]] | None = None, + input: str | None = None, + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, # noqa: ARG002 - reserved for future use + response_model: type[T] | None = None, + extended_reasoning: bool = False, + stateful: bool = False, # noqa: ARG002 - reserved for future use + previous_response_id: str | None = None, + api_type: OpenAIAPIType | None = None, + ) -> LLMChatResponse | T: + """ + Unified async call to OpenAI. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + input: Input string (Responses API shorthand). + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. + temperature: Sampling temperature (0.0-1.0). + response_model: Pydantic model class for structured response. + extended_reasoning: Enable extended reasoning (Responses API only). + stateful: Enable stateful conversation (Responses API only). + previous_response_id: Previous response ID for chaining (Responses API only). + api_type: Explicit API type, or None for auto-resolution. + + Returns: + LLMChatResponse or parsed Pydantic model if response_model is provided. + + Raises: + ValueError: If incompatible parameters are combined. 
+ """ + resolved_api_type = self._validate_and_resolve_api_type( + api_type, extended_reasoning, previous_response_id + ) + + if resolved_api_type == OpenAIAPIType.CHAT_COMPLETIONS: + if messages is None: + raise ValueError("messages is required for Chat Completions API") + + if response_model is not None: + kwargs = self._build_chat_completions_kwargs( + messages, system_prompt, max_tokens, response_model + ) + response = await self._async_client.beta.chat.completions.parse( + **kwargs, + response_format=response_model, + ) + else: + kwargs = self._build_chat_completions_kwargs( + messages, system_prompt, max_tokens, response_model + ) + response = await self._async_client.chat.completions.create(**kwargs) + + return self._parse_chat_completions_response(response, response_model) + + else: # Responses API + kwargs = self._build_responses_api_kwargs( + messages, input, system_prompt, max_tokens, + extended_reasoning, previous_response_id, response_model + ) + + if response_model is not None: + kwargs["text"] = {"format": {"type": "json_schema", "schema": response_model.model_json_schema()}} + + response = await self._async_client.responses.create(**kwargs) + return self._parse_responses_api_response(response, response_model) + + def call_stream_sync( + self, + messages: list[dict[str, str]] | None = None, + input: str | None = None, + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, # noqa: ARG002 - reserved for future use + extended_reasoning: bool = False, + stateful: bool = False, # noqa: ARG002 - reserved for future use + previous_response_id: str | None = None, + api_type: OpenAIAPIType | None = None, + ) -> Generator[str | LLMChatResponse, None, None]: + """ + Unified streaming call to OpenAI. + + Yields text chunks as they arrive, then yields the final LLMChatResponse. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + input: Input string (Responses API shorthand). + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. + temperature: Sampling temperature (0.0-1.0). + extended_reasoning: Enable extended reasoning (Responses API only). + stateful: Enable stateful conversation (Responses API only). + previous_response_id: Previous response ID for chaining (Responses API only). + api_type: Explicit API type, or None for auto-resolution. + + Yields: + str: Text chunks as they arrive. + LLMChatResponse: Final response with complete content and optional tool call. 
+ """ + resolved_api_type = self._validate_and_resolve_api_type( + api_type, extended_reasoning, previous_response_id + ) + + if resolved_api_type == OpenAIAPIType.CHAT_COMPLETIONS: + if messages is None: + raise ValueError("messages is required for Chat Completions API") + + kwargs = self._build_chat_completions_kwargs( + messages, system_prompt, max_tokens, response_model=None, stream=True + ) + stream = self._client.chat.completions.create(**kwargs) + + # Accumulate content and tool call data + full_content: list[str] = [] + tool_call_name: str | None = None + tool_call_args: list[str] = [] + + for chunk in stream: + delta = chunk.choices[0].delta if chunk.choices else None + if delta is None: + continue + + # Handle text content + if delta.content: + full_content.append(delta.content) + yield delta.content + + # Handle tool calls (streamed in chunks) + if delta.tool_calls: + for tc in delta.tool_calls: + if tc.function: + if tc.function.name: + tool_call_name = tc.function.name + if tc.function.arguments: + tool_call_args.append(tc.function.arguments) + + # Build final response + tool_call: LLMToolCall | None = None + if tool_call_name: + tool_call = LLMToolCall( + tool_name=tool_call_name, + tool_arguments=json.loads("".join(tool_call_args)) if tool_call_args else {}, + ) + + yield LLMChatResponse( + content="".join(full_content) if full_content else None, + tool_call=tool_call, + ) + + else: # Responses API streaming + kwargs = self._build_responses_api_kwargs( + messages, input, system_prompt, max_tokens, + extended_reasoning, previous_response_id, response_model=None, stream=True + ) + + stream = self._client.responses.create(**kwargs) + + full_content: list[str] = [] + tool_call_name: str | None = None + tool_call_args: list[str] = [] + reasoning_content: str | None = None + response_id: str | None = None + + for event in stream: + # Handle different event types from Responses API streaming + if hasattr(event, "type"): + if event.type == "response.created": + response_id = event.response.id + + elif event.type == "response.output_text.delta": + if hasattr(event, "delta"): + full_content.append(event.delta) + yield event.delta + + elif event.type == "response.function_call_arguments.delta": + if hasattr(event, "delta"): + tool_call_args.append(event.delta) + + elif event.type == "response.output_item.added": + if hasattr(event, "item") and event.item.type == "function_call": + tool_call_name = event.item.name + + # Build final response + tool_call: LLMToolCall | None = None + if tool_call_name: + tool_call = LLMToolCall( + tool_name=tool_call_name, + tool_arguments=json.loads("".join(tool_call_args)) if tool_call_args else {}, + ) + + yield LLMChatResponse( + content="".join(full_content) if full_content else None, + tool_call=tool_call, + response_id=response_id, + reasoning_content=reasoning_content, + ) diff --git a/web_hacker/llms/tools/__init__.py b/web_hacker/llms/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/web_hacker/llms/tools/guide_agent_tools.py b/web_hacker/llms/tools/guide_agent_tools.py new file mode 100644 index 00000000..6cba95c0 --- /dev/null +++ b/web_hacker/llms/tools/guide_agent_tools.py @@ -0,0 +1,39 @@ +""" +web_hacker/llms/tools/guide_agent_tools.py + +Tool functions for the guide agent. 
+""" + +from typing import Any + + +def start_routine_discovery_job_creation( + task_description: str, + expected_output_description: str, + input_parameters: list[dict[str, str]] | None = None, + filters_or_constraints: list[str] | None = None, + target_website: str | None = None, +) -> dict[str, Any]: + """ + Initiates the routine discovery process. + + Call this when you have gathered enough information about: + 1) What task the user wants to automate + 2) What data/output they expect + 3) What input parameters the routine should accept + 4) Any filters or constraints + + This tool requests user confirmation before executing. + + Args: + task_description: Description of the task/routine the user wants to create + expected_output_description: Description of what data the routine should return + input_parameters: List of input parameters with 'name' and 'description' keys + filters_or_constraints: Any filters or constraints the user mentioned + target_website: Target website/URL if mentioned by user + + Returns: + Result dict to be passed to routine discovery agent + """ + # TODO: implement the actual handoff logic + raise NotImplementedError("start_routine_discovery_job_creation not yet implemented") diff --git a/web_hacker/llms/tools/tool_utils.py b/web_hacker/llms/tools/tool_utils.py new file mode 100644 index 00000000..fcc4fe36 --- /dev/null +++ b/web_hacker/llms/tools/tool_utils.py @@ -0,0 +1,74 @@ +""" +web_hacker/llms/tools/tool_utils.py + +Utilities for converting Python functions to LLM tool definitions. +""" + +import inspect +from typing import Any, Callable, get_type_hints + +from pydantic import TypeAdapter + + +def extract_description_from_docstring(docstring: str | None) -> str: + """ + Extract the first paragraph from a docstring as the description. + + Args: + docstring: The function's docstring (func.__doc__) + + Returns: + The first paragraph of the docstring, or empty string if none. + """ + if not docstring: + return "" + + # split on double newlines to get first paragraph + paragraphs = docstring.strip().split("\n\n") + if not paragraphs: + return "" + + # clean up the first paragraph (remove leading/trailing whitespace from each line) + first_para = paragraphs[0] + lines = [line.strip() for line in first_para.split("\n")] + return " ".join(lines) + + +def generate_parameters_schema(func: Callable[..., Any]) -> dict[str, Any]: + """ + Generate JSON Schema for function parameters using pydantic. + + Args: + func: The function to generate schema for. Must have type hints. + + Returns: + JSON Schema dict with 'type', 'properties', and 'required' keys. 
+ """ + sig = inspect.signature(obj=func) + hints = get_type_hints(obj=func) + + properties: dict[str, Any] = {} + required: list[str] = [] + + for param_name, param in sig.parameters.items(): + if param_name in ("self", "cls"): + continue + + param_type = hints.get(param_name, Any) + # use pydantic TypeAdapter to generate schema for this type + schema = TypeAdapter(param_type).json_schema() + + # remove pydantic metadata that's not needed for tool schemas + schema.pop("title", None) + + properties[param_name] = schema + + # parameter is required if it has no default value + if param.default is inspect.Parameter.empty: + required.append(param_name) + + return { + "type": "object", + "properties": properties, + "required": required, + } diff --git a/web_hacker/routine_discovery/context_manager.py b/web_hacker/routine_discovery/context_manager.py index 73f90ac5..2b16b4f2 100644 --- a/web_hacker/routine_discovery/context_manager.py +++ b/web_hacker/routine_discovery/context_manager.py @@ -1,10 +1,17 @@ -from pydantic import BaseModel, field_validator, Field, ConfigDict -from openai import OpenAI -from abc import ABC, abstractmethod -import os +""" +web_hacker/routine_discovery/context_manager.py + +Context management abstractions and utilities for routine discovery. +""" + import json -import time +import os import shutil +import time +from abc import ABC, abstractmethod + +from openai import OpenAI +from pydantic import BaseModel, ConfigDict, Field, field_validator from web_hacker.utils.data_utils import get_text_from_html diff --git a/web_hacker/utils/exceptions.py b/web_hacker/utils/exceptions.py index ed299f6e..8b823432 100644 --- a/web_hacker/utils/exceptions.py +++ b/web_hacker/utils/exceptions.py @@ -81,3 +81,9 @@ class WebHackerError(Exception): """ Base exception for all Web Hacker errors. """ + + +class UnknownToolError(Exception): + """ + Raised when attempting to execute a tool that does not exist. + """