diff --git a/pyproject.toml b/pyproject.toml index 4280aba..4608af6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,9 @@ classifiers = [ "Topic :: Software Development :: Testing", ] dependencies = [ + "anthropic>=0.76.0", "openai>=2.6.1", + "jmespath>=1.0.1,<2", "pydantic>=2.11.4", "python-dotenv>=1.2.1", "requests>=2.31.0", @@ -46,7 +48,6 @@ dependencies = [ "beautifulsoup4>=4.14.2", "pylint>=3.0.0", "pytest>=8.3.5", - "jmespath>=1.0.1,<2", ] [project.optional-dependencies] @@ -60,7 +61,7 @@ web-hacker-discover = "web_hacker.scripts.discover_routine:main" web-hacker-execute = "web_hacker.scripts.execute_routine:main" [project.urls] -Homepage = "https://www.vectorly.app" +Homepage = "https://vectorly.app" Documentation = "https://github.com/VectorlyApp/web-hacker#readme" Repository = "https://github.com/VectorlyApp/web-hacker" Issues = "https://github.com/VectorlyApp/web-hacker/issues" diff --git a/scripts/run_guide_agent.py b/scripts/run_guide_agent.py new file mode 100755 index 0000000..75d7874 --- /dev/null +++ b/scripts/run_guide_agent.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python3 +""" +scripts/run_guide_agent.py + +Interactive terminal interface for the Guide Agent. +Guides users through creating web automation routines. +""" + +import json +import sys +import textwrap +from typing import Any + +from web_hacker.agents.guide_agent.guide_agent import GuideAgent +from web_hacker.data_models.chat import ( + ChatMessageType, + EmittedChatMessage, + PendingToolInvocation, + ToolInvocationStatus, +) + + +# ANSI color codes +class Colors: + """ANSI escape codes for terminal colors.""" + + RESET = "\033[0m" + BOLD = "\033[1m" + DIM = "\033[2m" + ITALIC = "\033[3m" + UNDERLINE = "\033[4m" + + # Foreground colors + BLACK = "\033[30m" + RED = "\033[31m" + GREEN = "\033[32m" + YELLOW = "\033[33m" + BLUE = "\033[34m" + MAGENTA = "\033[35m" + CYAN = "\033[36m" + WHITE = "\033[37m" + + # Bright foreground + BRIGHT_BLACK = "\033[90m" + BRIGHT_RED = "\033[91m" + BRIGHT_GREEN = "\033[92m" + BRIGHT_YELLOW = "\033[93m" + BRIGHT_BLUE = "\033[94m" + BRIGHT_MAGENTA = "\033[95m" + BRIGHT_CYAN = "\033[96m" + BRIGHT_WHITE = "\033[97m" + + # Background colors + BG_BLACK = "\033[40m" + BG_RED = "\033[41m" + BG_GREEN = "\033[42m" + BG_YELLOW = "\033[43m" + BG_BLUE = "\033[44m" + BG_MAGENTA = "\033[45m" + BG_CYAN = "\033[46m" + BG_WHITE = "\033[47m" + + +def colorize(text: str, *codes: str) -> str: + """Apply ANSI color codes to text.""" + return "".join(codes) + text + Colors.RESET + + +def print_wrapped(text: str, indent: str = " ", width: int = 80) -> None: + """Print text with word wrapping and indentation.""" + lines = text.split("\n") + for line in lines: + if line.strip(): + wrapped = textwrap.fill(line, width=width, initial_indent=indent, subsequent_indent=indent) + print(wrapped) + else: + print() + + +class TerminalGuideChat: + """Interactive terminal chat interface for the Guide Agent.""" + + BANNER = r""" + ╔══════════════════════════════════════════════════════════════════╗ + ║ ║ + ║ ██╗ ██╗███████╗ ██████╗████████╗ ██████╗ ██████╗ ██╗ ██╗ ██╗║ + ║ ██║ ██║██╔════╝██╔════╝╚══██╔══╝██╔═══██╗██╔══██╗██║ ╚██╗ ██╔╝║ + ║ ██║ ██║█████╗ ██║ ██║ ██║ ██║██████╔╝██║ ╚████╔╝ ║ + ║ ╚██╗ ██╔╝██╔══╝ ██║ ██║ ██║ ██║██╔══██╗██║ ╚██╔╝ ║ + ║ ╚████╔╝ ███████╗╚██████╗ ██║ ╚██████╔╝██║ ██║███████╗██║ ║ + ║ ╚═══╝ ╚══════╝ ╚═════╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝╚══════╝╚═╝ ║ + ║ ║ + ║ Guide Agent Terminal ║ + ║ ║ + ╚══════════════════════════════════════════════════════════════════╝ + """ + + WELCOME_MESSAGE = """ + 
Welcome! I'll help you create a web automation routine from your + CDP (Chrome DevTools Protocol) captures. + + I'll analyze your network transactions to identify relevant API + endpoints, required cookies, headers, and request patterns that + can be turned into a reusable routine. + + Commands: + • Type your message and press Enter to chat + • Type 'quit' or 'exit' to leave + • Type 'reset' to start a new conversation + + Links: + • Docs: https://vectorly.app/docs + • Console: https://console.vectorly.app + """ + + def __init__(self) -> None: + """Initialize the terminal chat interface.""" + self._pending_invocation: PendingToolInvocation | None = None + self._streaming_started: bool = False + self._agent = GuideAgent( + emit_message_callable=self._handle_message, + stream_chunk_callable=self._handle_stream_chunk, + ) + + def _handle_stream_chunk(self, chunk: str) -> None: + """ + Handle a streaming text chunk from the LLM. + + Args: + chunk: A text chunk from the streaming response. + """ + if not self._streaming_started: + # Print the header before the first chunk + print() + print(colorize(" Assistant", Colors.BOLD, Colors.CYAN) + colorize(":", Colors.DIM)) + print() + print(" ", end="", flush=True) + self._streaming_started = True + + # Print chunk without newline, flush immediately + print(chunk, end="", flush=True) + + def _handle_message(self, message: EmittedChatMessage) -> None: + """ + Handle messages emitted by the Guide Agent. + + Args: + message: The emitted message from the agent. + """ + if message.type == ChatMessageType.CHAT_RESPONSE: + # If we were streaming, just finish with newlines (content already printed) + if self._streaming_started: + print() # End the streamed line + print() # Add spacing + self._streaming_started = False + else: + self._print_assistant_message(message.content or "") + + elif message.type == ChatMessageType.TOOL_INVOCATION_REQUEST: + if message.tool_invocation: + self._pending_invocation = message.tool_invocation + self._print_tool_request(message.tool_invocation) + + elif message.type == ChatMessageType.TOOL_INVOCATION_RESULT: + if message.tool_invocation: + self._print_tool_result( + message.tool_invocation, + message.tool_result, + message.error, + ) + + elif message.type == ChatMessageType.ERROR: + self._print_error(message.error or "Unknown error") + + def _print_assistant_message(self, content: str) -> None: + """Print an assistant response.""" + print() + print(colorize(" Assistant", Colors.BOLD, Colors.CYAN) + colorize(":", Colors.DIM)) + print() + print_wrapped(content, indent=" ") + print() + + def _print_tool_request(self, invocation: PendingToolInvocation) -> None: + """Print a tool invocation request with formatted arguments.""" + print() + print(colorize(" ┌─────────────────────────────────────────────────────────────────┐", Colors.YELLOW)) + print(colorize(" │", Colors.YELLOW) + colorize(" TOOL INVOCATION REQUEST", Colors.BOLD, Colors.YELLOW) + colorize(" │", Colors.YELLOW)) + print(colorize(" ├─────────────────────────────────────────────────────────────────┤", Colors.YELLOW)) + print(colorize(" │", Colors.YELLOW)) + + # Tool name + print(colorize(" │ ", Colors.YELLOW) + colorize("Tool: ", Colors.DIM) + colorize(invocation.tool_name, Colors.BRIGHT_WHITE, Colors.BOLD)) + + # Arguments + print(colorize(" │", Colors.YELLOW)) + print(colorize(" │ ", Colors.YELLOW) + colorize("Arguments:", Colors.DIM)) + + args_json = json.dumps(invocation.tool_arguments, indent=4) + for line in args_json.split("\n"): + print(colorize(" │ ", 
Colors.YELLOW) + colorize(line, Colors.WHITE)) + + print(colorize(" │", Colors.YELLOW)) + print(colorize(" └─────────────────────────────────────────────────────────────────┘", Colors.YELLOW)) + print() + print(colorize(" Do you want to proceed? ", Colors.BRIGHT_YELLOW) + colorize("[y/n]", Colors.DIM) + ": ", end="") + + def _print_tool_result( + self, + invocation: PendingToolInvocation, + result: dict[str, Any] | None, + error: str | None, + ) -> None: + """Print a tool invocation result.""" + print() + + if invocation.status == ToolInvocationStatus.DENIED: + print(colorize(" ✗ Tool invocation denied", Colors.YELLOW)) + + elif invocation.status == ToolInvocationStatus.EXECUTED: + print(colorize(" ✓ Tool executed successfully", Colors.GREEN, Colors.BOLD)) + if result: + print() + print(colorize(" Result:", Colors.DIM)) + result_json = json.dumps(result, indent=4) + for line in result_json.split("\n"): + print(colorize(" " + line, Colors.GREEN)) + + elif invocation.status == ToolInvocationStatus.FAILED: + print(colorize(" ✗ Tool execution failed", Colors.RED, Colors.BOLD)) + if error: + print(colorize(f" Error: {error}", Colors.RED)) + + print() + + def _print_error(self, error: str) -> None: + """Print an error message.""" + print() + print(colorize(" ⚠ Error: ", Colors.RED, Colors.BOLD) + colorize(error, Colors.RED)) + print() + + def _print_user_prompt(self) -> None: + """Print the user input prompt.""" + print(colorize(" You", Colors.BOLD, Colors.GREEN) + colorize(": ", Colors.DIM), end="") + + def _handle_tool_confirmation(self, user_input: str) -> bool: + """ + Handle yes/no confirmation for pending tool invocation. + + Args: + user_input: The user's input. + + Returns: + True if the confirmation was handled, False otherwise. + """ + if not self._pending_invocation: + return False + + normalized = user_input.strip().lower() + + if normalized in ("y", "yes"): + invocation_id = self._pending_invocation.invocation_id + self._pending_invocation = None + self._agent.confirm_tool_invocation(invocation_id) + return True + + elif normalized in ("n", "no"): + invocation_id = self._pending_invocation.invocation_id + self._pending_invocation = None + self._agent.deny_tool_invocation(invocation_id, reason="User declined") + return True + + else: + print(colorize(" Please enter 'y' or 'n': ", Colors.YELLOW), end="") + return True # Still in confirmation mode + + def run(self) -> None: + """Run the interactive chat loop.""" + # Print banner and welcome + print(colorize(self.BANNER, Colors.BRIGHT_MAGENTA, Colors.BOLD)) + print(colorize(self.WELCOME_MESSAGE, Colors.DIM)) + print(colorize(" " + "─" * 67, Colors.DIM)) + print() + + while True: + try: + # Handle pending tool confirmation + if self._pending_invocation: + user_input = input() + if self._handle_tool_confirmation(user_input): + if not self._pending_invocation: + # Confirmation was processed, continue to next iteration + continue + else: + # Still waiting for valid y/n + continue + else: + self._print_user_prompt() + user_input = input() + + # Check for commands + normalized = user_input.strip().lower() + + if normalized in ("quit", "exit", "q"): + print() + print(colorize(" Goodbye! 
👋", Colors.CYAN, Colors.BOLD)) + print() + break + + if normalized == "reset": + self._agent.reset() + self._pending_invocation = None + print() + print(colorize(" ↺ Conversation reset", Colors.YELLOW)) + print() + continue + + if not user_input.strip(): + continue + + # Process the message + self._agent.process_user_message(user_input) + + except KeyboardInterrupt: + print() + print(colorize("\n Interrupted. Goodbye! 👋", Colors.CYAN)) + print() + break + + except EOFError: + print() + print(colorize("\n Goodbye! 👋", Colors.CYAN)) + print() + break + + +def main() -> None: + """Entry point for the guide agent terminal.""" + try: + chat = TerminalGuideChat() + chat.run() + except Exception as e: + print(colorize(f"\n Fatal error: {e}", Colors.RED, Colors.BOLD), file=sys.stderr) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/unit/test_tool_utils.py b/tests/unit/test_tool_utils.py new file mode 100644 index 0000000..b5585e0 --- /dev/null +++ b/tests/unit/test_tool_utils.py @@ -0,0 +1,106 @@ +""" +tests/unit/test_tool_utils.py + +Unit tests for LLM tool utilities. +""" + +from web_hacker.llms.tools.tool_utils import ( + extract_description_from_docstring, + generate_parameters_schema, +) + + +class TestExtractDescriptionFromDocstring: + """Tests for docstring description extraction.""" + + def test_single_line_docstring(self) -> None: + docstring = "This is a simple description." + result = extract_description_from_docstring(docstring) + assert result == "This is a simple description." + + def test_multiline_first_paragraph(self) -> None: + docstring = """This is a description + that spans multiple lines.""" + result = extract_description_from_docstring(docstring) + assert result == "This is a description that spans multiple lines." + + def test_extracts_only_first_paragraph(self) -> None: + docstring = """First paragraph here. + + Args: + foo: Some argument. + """ + result = extract_description_from_docstring(docstring) + assert result == "First paragraph here." + + def test_none_docstring(self) -> None: + result = extract_description_from_docstring(None) + assert result == "" + + def test_empty_docstring(self) -> None: + result = extract_description_from_docstring("") + assert result == "" + + def test_strips_leading_whitespace(self) -> None: + docstring = """ + Description with leading whitespace. + """ + result = extract_description_from_docstring(docstring) + assert result == "Description with leading whitespace." 
+ + +class TestGenerateParametersSchema: + """Tests for function parameter schema generation.""" + + def test_simple_string_params(self) -> None: + def example_func(name: str, value: str) -> None: + pass + + schema = generate_parameters_schema(example_func) + assert schema["type"] == "object" + assert schema["required"] == ["name", "value"] + assert schema["properties"]["name"]["type"] == "string" + assert schema["properties"]["value"]["type"] == "string" + + def test_optional_params_not_required(self) -> None: + def example_func(required_param: str, optional_param: str | None = None) -> None: + pass + + schema = generate_parameters_schema(example_func) + assert schema["required"] == ["required_param"] + assert "optional_param" in schema["properties"] + assert "optional_param" not in schema["required"] + + def test_list_type(self) -> None: + def example_func(items: list[str]) -> None: + pass + + schema = generate_parameters_schema(example_func) + assert schema["properties"]["items"]["type"] == "array" + assert schema["properties"]["items"]["items"]["type"] == "string" + + def test_dict_type(self) -> None: + def example_func(data: dict[str, int]) -> None: + pass + + schema = generate_parameters_schema(example_func) + props = schema["properties"]["data"] + assert props["type"] == "object" + assert props["additionalProperties"]["type"] == "integer" + + def test_skips_self_parameter(self) -> None: + class Example: + def method(self, name: str) -> None: + pass + + schema = generate_parameters_schema(Example.method) + assert "self" not in schema["properties"] + assert schema["required"] == ["name"] + + def test_nullable_type_uses_anyof(self) -> None: + def example_func(value: str | None) -> None: + pass + + schema = generate_parameters_schema(example_func) + # pydantic represents str | None as anyOf + assert "anyOf" in schema["properties"]["value"] diff --git a/uv.lock b/uv.lock index ddc6619..e0f644a 100644 --- a/uv.lock +++ b/uv.lock @@ -11,6 +11,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] +[[package]] +name = "anthropic" +version = "0.76.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "docstring-parser" }, + { name = "httpx" }, + { name = "jiter" }, + { name = "pydantic" }, + { name = "sniffio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6e/be/d11abafaa15d6304826438170f7574d750218f49a106c54424a40cef4494/anthropic-0.76.0.tar.gz", hash = "sha256:e0cae6a368986d5cf6df743dfbb1b9519e6a9eee9c6c942ad8121c0b34416ffe", size = 495483, upload-time = "2026-01-13T18:41:14.908Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/70/7b0fd9c1a738f59d3babe2b4212031c34ab7d0fda4ffef15b58a55c5bcea/anthropic-0.76.0-py3-none-any.whl", hash = "sha256:81efa3113901192af2f0fe977d3ec73fdadb1e691586306c4256cd6d5ccc331c", size = 390309, upload-time = "2026-01-13T18:41:13.483Z" }, +] + [[package]] name = "anyio" version = "4.12.1" @@ -179,6 +198,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, 
upload-time = "2023-12-24T09:54:30.421Z" }, ] +[[package]] +name = "docstring-parser" +version = "0.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/9d/c3b43da9515bd270df0f80548d9944e389870713cc1fe2b8fb35fe2bcefd/docstring_parser-0.17.0.tar.gz", hash = "sha256:583de4a309722b3315439bb31d64ba3eebada841f2e2cee23b99df001434c912", size = 27442, upload-time = "2025-07-21T07:35:01.868Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl", hash = "sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708", size = 36896, upload-time = "2025-07-21T07:35:00.684Z" }, +] + [[package]] name = "executing" version = "2.2.1" @@ -905,9 +933,10 @@ wheels = [ [[package]] name = "web-hacker" -version = "1.2.2" +version = "1.2.3" source = { editable = "." } dependencies = [ + { name = "anthropic" }, { name = "beautifulsoup4" }, { name = "jmespath" }, { name = "openai" }, @@ -929,9 +958,10 @@ dev = [ [package.metadata] requires-dist = [ + { name = "anthropic", specifier = ">=0.76.0" }, { name = "beautifulsoup4", specifier = ">=4.14.2" }, { name = "ipykernel", marker = "extra == 'dev'", specifier = ">=6.29.5" }, - { name = "jmespath", specifier = ">=1.0.1" }, + { name = "jmespath", specifier = ">=1.0.1,<2" }, { name = "openai", specifier = ">=2.6.1" }, { name = "pydantic", specifier = ">=2.11.4" }, { name = "pylint", specifier = ">=3.0.0" }, diff --git a/web_hacker/agents/guide_agent/guide_agent.py b/web_hacker/agents/guide_agent/guide_agent.py new file mode 100644 index 0000000..030552b --- /dev/null +++ b/web_hacker/agents/guide_agent/guide_agent.py @@ -0,0 +1,552 @@ +""" +web_hacker/agents/guide_agent/guide_agent.py + +Guide agent that guides the user through the process of creating or editing a routine. +""" + +from datetime import datetime +from uuid import uuid4 +from typing import Any, Callable + +from web_hacker.data_models.chat import ( + Chat, + ChatThread, + ChatRole, + EmittedChatMessage, + ChatMessageType, + LLMChatResponse, + PendingToolInvocation, + ToolInvocationStatus, +) +from web_hacker.data_models.llms import LLMModel, OpenAIModel +from web_hacker.llms.llm_client import LLMClient +from web_hacker.llms.tools.guide_agent_tools import start_routine_discovery_job_creation +from web_hacker.utils.exceptions import UnknownToolError +from web_hacker.utils.logger import get_logger + + +logger = get_logger(name=__name__) + + +class GuideAgent: + """ + Guide agent that guides the user through the process of creating or editing a routine. + + The agent maintains a ChatThread with Chat messages and uses LLM tool-calling to determine + when to initiate routine discovery. Tool invocations require user confirmation + via callback before execution. + + Usage: + def handle_message(message: EmittedChatMessage) -> None: + print(f"[{message.type}] {message.content}") + + agent = GuideAgent(emit_message_callable=handle_message) + agent.process_user_message("I want to search for flights") + """ + + # Class constants ______________________________________________________________________________________________________ + + SYSTEM_PROMPT: str = """You are a helpful assistant that guides users through creating \ +web automation routines using the Web Hacker tool. + +## What is Web Hacker? + +Web Hacker is a tool that creates reusable web automation routines by learning from \ +user demonstrations. 
Users record themselves performing a task on a website, and \ +Web Hacker generates a parameterized routine that can be executed programmatically. + +## Your Role + +Your job is to help users define their automation needs by gathering: + +1. **TASK**: What task do they want to automate? + - Examples: "Search for train tickets", "Download a research paper", "Look up company info" + +2. **OUTPUT**: What data/output should the routine return? + - Examples: "List of available trains with prices", "PDF file of the paper", "Company registration details" + +3. **PARAMETERS**: What input parameters will the routine need? For each: + - Name (e.g., "origin", "destination", "departure_date") + - Description (e.g., "The departure station name") + +4. **CONSTRAINTS**: Any filters or constraints? + - Examples: "Only direct trains", "Papers from 2024", "Only active companies" + +5. **WEBSITE**: What website should be used? + - Examples: "amtrak.com", "arxiv.org", "sec.gov" + +## Guidelines + +- Be conversational and helpful +- Ask clarifying questions if needed +- Don't overwhelm the user with too many questions at once +- When you have enough information, use the start_routine_discovery_job_creation tool +- If the user asks what this tool does, explain it clearly""" + + # Magic methods ________________________________________________________________________________________________________ + + def __init__( + self, + emit_message_callable: Callable[[EmittedChatMessage], None], + persist_chat_callable: Callable[[Chat], None] | None = None, + persist_chat_thread_callable: Callable[[ChatThread], None] | None = None, + stream_chunk_callable: Callable[[str], None] | None = None, + llm_model: LLMModel = OpenAIModel.GPT_5_MINI, + chat_thread: ChatThread | None = None, + existing_chats: list[Chat] | None = None, + ) -> None: + """ + Initialize the guide agent. + + Args: + emit_message_callable: Callback function to emit messages to the host. + persist_chat_callable: Optional callback to persist Chat objects (for DynamoDB). + persist_chat_thread_callable: Optional callback to persist ChatThread (for DynamoDB). + stream_chunk_callable: Optional callback for streaming text chunks as they arrive. + llm_model: The LLM model to use for conversation. + chat_thread: Existing ChatThread to continue, or None for new conversation. + existing_chats: Existing Chat messages if loading from persistence. 
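+
+        Example (illustrative sketch; assumes an OpenAI API key is configured
+        for the default model):
+            agent = GuideAgent(
+                emit_message_callable=lambda message: print(message.type, message.content),
+            )
+            agent.process_user_message("I want to search for flights")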
+ """ + self._emit_message_callable = emit_message_callable + self._persist_chat_callable = persist_chat_callable + self._persist_chat_thread_callable = persist_chat_thread_callable + self._stream_chunk_callable = stream_chunk_callable + + self.llm_model = llm_model + self.llm_client = LLMClient(llm_model) + + # Register tools + self._register_tools() + + # Initialize or load conversation state + self._thread = chat_thread or ChatThread() + self._chats: dict[str, Chat] = {} + if existing_chats: + for chat in existing_chats: + self._chats[chat.id] = chat + + # Persist initial thread if callback provided + if self._persist_chat_thread_callable and chat_thread is None: + self._persist_chat_thread_callable(self._thread) + + logger.info( + "Instantiated GuideAgent with model: %s, thread_id: %s", + llm_model, + self._thread.id, + ) + + # Properties ___________________________________________________________________________________________________________ + + @property + def thread_id(self) -> str: + """Return the current thread ID.""" + return self._thread.id + + @property + def has_pending_tool_invocation(self) -> bool: + """Check if there's a pending tool invocation awaiting confirmation.""" + return self._thread.pending_tool_invocation is not None + + # Private methods ______________________________________________________________________________________________________ + + def _register_tools(self) -> None: + """Register all tools with the LLM client.""" + self.llm_client.register_tool_from_function( + func=start_routine_discovery_job_creation, + ) + + def _emit_message(self, message: EmittedChatMessage) -> None: + """Emit a message via the callback.""" + self._emit_message_callable(message) + + def _add_chat(self, role: ChatRole, content: str) -> Chat: + """ + Create and store a new Chat, update thread, persist if callbacks set. + + Args: + role: The role of the message sender. + content: The content of the message. + + Returns: + The created Chat object. + """ + chat = Chat( + chat_thread_id=self._thread.id, + role=role, + content=content, + ) + self._chats[chat.id] = chat + self._thread.chat_ids.append(chat.id) + self._thread.updated_at = int(datetime.now().timestamp()) + + # Persist if callbacks provided + if self._persist_chat_callable: + self._persist_chat_callable(chat) + if self._persist_chat_thread_callable: + self._persist_chat_thread_callable(self._thread) + + return chat + + def _build_messages_for_llm(self) -> list[dict[str, str]]: + """ + Build messages list for LLM from Chat objects. + + Returns: + List of message dicts with 'role' and 'content' keys. + """ + messages: list[dict[str, str]] = [] + for chat_id in self._thread.chat_ids: + chat = self._chats.get(chat_id) + if chat: + messages.append({ + "role": chat.role.value, + "content": chat.content, + }) + return messages + + def _create_tool_invocation_request( + self, + tool_name: str, + tool_arguments: dict[str, Any], + ) -> PendingToolInvocation: + """ + Create a tool invocation request for user confirmation. 
+ + Args: + tool_name: Name of the tool to invoke + tool_arguments: Arguments for the tool + + Returns: + PendingToolInvocation stored in state and ready to emit + """ + invocation_id = str(uuid4()) + + pending = PendingToolInvocation( + invocation_id=invocation_id, + tool_name=tool_name, + tool_arguments=tool_arguments, + ) + + # Store in thread + self._thread.pending_tool_invocation = pending + self._thread.updated_at = int(datetime.now().timestamp()) + + if self._persist_chat_thread_callable: + self._persist_chat_thread_callable(self._thread) + + logger.info( + "Created tool invocation request: %s (tool: %s)", + invocation_id, + tool_name, + ) + + return pending + + def _execute_tool( + self, + tool_name: str, + tool_arguments: dict[str, Any], + ) -> dict[str, Any]: + """ + Execute a confirmed tool invocation. + + Args: + tool_name: Name of the tool to execute + tool_arguments: Arguments for the tool + + Returns: + Tool execution result with thread_id and params + + Raises: + UnknownToolError: If tool_name is unknown + """ + if tool_name == start_routine_discovery_job_creation.__name__: + logger.info( + "Executing tool %s with args: %s", + tool_name, + tool_arguments, + ) + result = start_routine_discovery_job_creation(**tool_arguments) + return { + "thread_id": self._thread.id, + **result, + } + + logger.error("Unknown tool \"%s\" with arguments: %s", tool_name, tool_arguments) + raise UnknownToolError(f"Unknown tool \"{tool_name}\" with arguments: {tool_arguments}") + + # Public methods _______________________________________________________________________________________________________ + + def process_user_message(self, content: str) -> None: + """ + Process a user message and emit responses via callback. + + This method handles the conversation loop: + 1. Adds user message to history + 2. Calls LLM to generate response + 3. 
Emits chat response or tool invocation request + + Args: + content: The user's message content + """ + # Block new messages if there's a pending tool invocation + if self._thread.pending_tool_invocation: + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.ERROR, + error="Please confirm or deny the pending tool invocation before sending new messages", + ) + ) + return + + # Add user message to history + self._add_chat(ChatRole.USER, content) + + # Build messages and call LLM + messages = self._build_messages_for_llm() + + try: + # Use streaming if chunk callback is set + if self._stream_chunk_callable: + response = self._process_streaming_response(messages) + else: + response = self.llm_client.chat_sync( + messages=messages, + system_prompt=self.SYSTEM_PROMPT, + ) + + # Handle text response + if response.content: + self._add_chat(ChatRole.ASSISTANT, response.content) + # Always emit CHAT_RESPONSE (handler checks if streaming occurred) + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.CHAT_RESPONSE, + content=response.content, + ) + ) + + # Handle tool call if present + if response.tool_call: + pending = self._create_tool_invocation_request( + response.tool_call.tool_name, + response.tool_call.tool_arguments, + ) + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.TOOL_INVOCATION_REQUEST, + tool_invocation=pending, + ) + ) + + except Exception as e: + logger.exception("Error processing user message: %s", e) + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.ERROR, + error=str(e), + ) + ) + + def _process_streaming_response(self, messages: list[dict[str, str]]) -> LLMChatResponse: + """ + Process LLM response with streaming, calling chunk callback for each chunk. + + Args: + messages: The messages to send to the LLM. + + Returns: + The final LLMChatResponse with complete content. + """ + response: LLMChatResponse | None = None + + for item in self.llm_client.chat_stream_sync( + messages=messages, + system_prompt=self.SYSTEM_PROMPT, + ): + if isinstance(item, str): + # Text chunk - call the callback + if self._stream_chunk_callable: + self._stream_chunk_callable(item) + elif isinstance(item, LLMChatResponse): + # Final response + response = item + + if response is None: + raise ValueError("No final response received from streaming LLM") + + return response + + def confirm_tool_invocation(self, invocation_id: str) -> None: + """ + Confirm a pending tool invocation and execute it. 
+ + Args: + invocation_id: ID of the tool invocation to confirm + + Emits: + - TOOL_INVOCATION_RESULT with status "executed" and result on success + - ERROR if no pending invocation or ID mismatch + """ + pending = self._thread.pending_tool_invocation + + if not pending: + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.ERROR, + error="No pending tool invocation to confirm", + ) + ) + return + + if pending.invocation_id != invocation_id: + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.ERROR, + error=f"Invocation ID mismatch: expected {pending.invocation_id}", + ) + ) + return + + # Update status + pending.status = ToolInvocationStatus.CONFIRMED + + try: + result = self._execute_tool(pending.tool_name, pending.tool_arguments) + pending.status = ToolInvocationStatus.EXECUTED + self._thread.pending_tool_invocation = None + self._thread.updated_at = int(datetime.now().timestamp()) + + if self._persist_chat_thread_callable: + self._persist_chat_thread_callable(self._thread) + + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.TOOL_INVOCATION_RESULT, + tool_invocation=pending, + tool_result=result, + ) + ) + + logger.info( + "Tool invocation %s executed successfully", + invocation_id, + ) + + except Exception as e: + pending.status = ToolInvocationStatus.FAILED + + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.TOOL_INVOCATION_RESULT, + tool_invocation=pending, + error=str(e), + ) + ) + + logger.exception( + "Tool invocation %s failed: %s", + invocation_id, + e, + ) + + def deny_tool_invocation( + self, + invocation_id: str, + reason: str | None = None, + ) -> None: + """ + Deny a pending tool invocation. + + Args: + invocation_id: ID of the tool invocation to deny + reason: Optional reason for denial + + Emits: + - TOOL_INVOCATION_RESULT with status "denied" + - ERROR if no pending invocation or ID mismatch + """ + pending = self._thread.pending_tool_invocation + + if not pending: + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.ERROR, + error="No pending tool invocation to deny", + ) + ) + return + + if pending.invocation_id != invocation_id: + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.ERROR, + error=f"Invocation ID mismatch: expected {pending.invocation_id}", + ) + ) + return + + # Update status and clear pending + pending.status = ToolInvocationStatus.DENIED + self._thread.pending_tool_invocation = None + self._thread.updated_at = int(datetime.now().timestamp()) + + if self._persist_chat_thread_callable: + self._persist_chat_thread_callable(self._thread) + + # Add denial to conversation history + denial_message = "Tool invocation denied" + if reason: + denial_message += f": {reason}" + self._add_chat(ChatRole.SYSTEM, denial_message) + + self._emit_message( + EmittedChatMessage( + type=ChatMessageType.TOOL_INVOCATION_RESULT, + tool_invocation=pending, + content=denial_message, + ) + ) + + logger.info( + "Tool invocation %s denied: %s", + invocation_id, + reason or "no reason provided", + ) + + def get_thread(self) -> ChatThread: + """ + Get the current conversation thread. + + Returns: + Current ChatThread + """ + return self._thread + + def get_chats(self) -> list[Chat]: + """ + Get all Chat messages in order. + + Returns: + List of Chat objects in conversation order. + """ + return [self._chats[chat_id] for chat_id in self._thread.chat_ids if chat_id in self._chats] + + def reset(self) -> None: + """ + Reset the conversation to a fresh state. 
+ + Generates a new thread and clears all messages. + """ + old_thread_id = self._thread.id + self._thread = ChatThread() + self._chats = {} + + if self._persist_chat_thread_callable: + self._persist_chat_thread_callable(self._thread) + + logger.info( + "Reset conversation from %s to %s", + old_thread_id, + self._thread.id, + ) diff --git a/web_hacker/cdp/connection.py b/web_hacker/cdp/connection.py index 30712c8..645ad5b 100644 --- a/web_hacker/cdp/connection.py +++ b/web_hacker/cdp/connection.py @@ -10,7 +10,6 @@ """ import json -import random import time from json import JSONDecodeError from typing import Callable diff --git a/web_hacker/config.py b/web_hacker/config.py index 7ec9440..50ae9f7 100644 --- a/web_hacker/config.py +++ b/web_hacker/config.py @@ -31,6 +31,7 @@ class Config(): # API keys OPENAI_API_KEY: str | None = os.getenv("OPENAI_API_KEY") + ANTHROPIC_API_KEY: str | None = os.getenv("ANTHROPIC_API_KEY") @classmethod def as_dict(cls) -> dict[str, Any]: diff --git a/web_hacker/data_models/chat.py b/web_hacker/data_models/chat.py new file mode 100644 index 0000000..486be05 --- /dev/null +++ b/web_hacker/data_models/chat.py @@ -0,0 +1,177 @@ +""" +web_hacker/data_models/chat.py + +Chat data models for the guide agent conversation system. +""" + +from datetime import datetime, timezone +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, Field + +from web_hacker.data_models.resource_base import ResourceBase + + +class ChatRole(StrEnum): + """ + Role in a chat message. + """ + USER = "user" + ASSISTANT = "assistant" # AI + SYSTEM = "system" + TOOL = "tool" + + +class ToolInvocationStatus(StrEnum): + """ + Status of a tool invocation. + """ + PENDING_CONFIRMATION = "pending_confirmation" + CONFIRMED = "confirmed" + DENIED = "denied" + EXECUTED = "executed" + FAILED = "failed" + + +class PendingToolInvocation(BaseModel): + """ + A tool invocation awaiting user confirmation. + """ + invocation_id: str = Field( + ..., + description="Unique ID for this invocation (UUIDv4)", + ) + tool_name: str = Field( + ..., + description="Name of the tool to invoke", + ) + tool_arguments: dict[str, Any] = Field( + ..., + description="Arguments to pass to the tool", + ) + status: ToolInvocationStatus = Field( + default=ToolInvocationStatus.PENDING_CONFIRMATION, + description="Current status of the invocation", + ) + created_at: datetime = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + description="When the invocation was created", + ) + + +class Chat(ResourceBase): + """ + A single message in a conversation thread. + + Each Chat belongs to a ChatThread and contains a single message + from a user, assistant, system, or tool. + + ID format: Chat_ + """ + + chat_thread_id: str = Field( + ..., + description="ID of the parent ChatThread this message belongs to", + ) + role: ChatRole = Field( + ..., + description="The role of the message sender (user, assistant, system, tool)", + ) + content: str = Field( + ..., + description="The content of the message", + ) + + +class ChatThread(ResourceBase): + """ + Container for a conversation thread with 0+ Chat messages. + + A ChatThread maintains an ordered list of Chat IDs and tracks + any pending tool invocations that require user confirmation. 
+ + ID format: ChatThread_ + """ + + chat_ids: list[str] = Field( + default_factory=list, + description="Ordered list of Chat IDs in this thread (bidirectional link)", + ) + pending_tool_invocation: PendingToolInvocation | None = Field( + default=None, + description="Tool invocation awaiting user confirmation, if any", + ) + + +class ChatMessageType(StrEnum): + """ + Types of messages the chat can emit via callback. + """ + + CHAT_RESPONSE = "chat_response" + TOOL_INVOCATION_REQUEST = "tool_invocation_request" + TOOL_INVOCATION_RESULT = "tool_invocation_result" + ERROR = "error" + + +class EmittedChatMessage(BaseModel): + """ + Message emitted by the guide agent via callback. + + This is the internal message format used by GuideAgent to communicate + with its host (e.g., CLI, WebSocket handler in servers repo). + """ + + type: ChatMessageType = Field( + ..., + description="The type of message being emitted", + ) + timestamp: datetime = Field( + default_factory=lambda: datetime.now(tz=timezone.utc), + description="When the message was created", + ) + content: str | None = Field( + default=None, + description="Text content for chat responses or error messages", + ) + tool_invocation: PendingToolInvocation | None = Field( + default=None, + description="Tool invocation details for request/result messages", + ) + tool_result: dict[str, Any] | None = Field( + default=None, + description="Result data from tool execution", + ) + error: str | None = Field( + default=None, + description="Error message if type is ERROR", + ) + + +class LLMToolCall(BaseModel): + """ + A tool call requested by the LLM. + """ + tool_name: str = Field( + ..., + description="Name of the tool to invoke", + ) + tool_arguments: dict[str, Any] = Field( + ..., + description="Arguments to pass to the tool", + ) + + +class LLMChatResponse(BaseModel): + """ + Response from an LLM chat completion with tool support. + """ + content: str | None = Field( + default=None, + description="Text content of the response", + ) + tool_call: LLMToolCall | None = Field( + default=None, + description="Tool call requested by the LLM, if any", + ) diff --git a/web_hacker/data_models/llms.py b/web_hacker/data_models/llms.py new file mode 100644 index 0000000..3eeb8a9 --- /dev/null +++ b/web_hacker/data_models/llms.py @@ -0,0 +1,51 @@ +""" +web_hacker/data_models/llm_vendor_models.py + +This module contains the LLM vendor models. +""" + +from enum import StrEnum + + +class LLMVendor(StrEnum): + """Represents the vendor of an LLM.""" + OPENAI = "openai" + ANTHROPIC = "anthropic" + + +class OpenAIModel(StrEnum): + """OpenAI models.""" + GPT_5_2 = "gpt-5.2" + GPT_5_MINI = "gpt-5-mini" + GPT_5_NANO = "gpt-5-nano" + + +class AnthropicModel(StrEnum): + """Anthropic models.""" + CLAUDE_OPUS_4_5 = "claude-opus-4-5-20251101" + CLAUDE_SONNET_4_5 = "claude-sonnet-4-5-20250929" + CLAUDE_HAIKU_4_5 = "claude-haiku-4-5-20251001" + + +# Build unified model enum and vendor lookup from vendor-specific enums +_model_to_vendor: dict[str, LLMVendor] = {} +_all_models: dict[str, str] = {} + +for model in OpenAIModel: + _model_to_vendor[model.value] = LLMVendor.OPENAI + _all_models[model.name] = model.value + +for model in AnthropicModel: + _model_to_vendor[model.value] = LLMVendor.ANTHROPIC + _all_models[model.name] = model.value + + +# Union type: any OpenAIModel or AnthropicModel is an LLMModel +type LLMModel = OpenAIModel | AnthropicModel + + +def get_model_vendor(model: LLMModel) -> LLMVendor: + """ + Returns the vendor of the LLM model. 
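+
+    Example:
+        assert get_model_vendor(OpenAIModel.GPT_5_MINI) is LLMVendor.OPENAI
+        assert get_model_vendor(AnthropicModel.CLAUDE_SONNET_4_5) is LLMVendor.ANTHROPIC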
+ """ + return _model_to_vendor[model.value] diff --git a/web_hacker/data_models/resource_base.py b/web_hacker/data_models/resource_base.py new file mode 100644 index 0000000..2361e99 --- /dev/null +++ b/web_hacker/data_models/resource_base.py @@ -0,0 +1,65 @@ +""" +src/data_models/resource_base.py + +Base class for all resources that provides a standardized ID format. + +ID format: [resourceType]_[uuidv4] +Examples: "Chat_123e4567-e89b-12d3-a456-426614174000" +""" + +from abc import ABC +from datetime import datetime +from typing import Any +from uuid import uuid4 + +from pydantic import BaseModel, Field + + +class ResourceBase(BaseModel, ABC): + """ + Base class for all resources that provides a standardized ID format. + + ID format: [resourceType]_[uuidv4] + Examples: "Chat_123e4567-e89b-12d3-a456-426614174000" + """ + + # standardized resource ID in format "[resourceType]_[uuid]" + id: str = Field( + default_factory=lambda: f"ResourceBase_{uuid4()}", + description="Resource ID in format [resourceType]_[uuidv4]" + ) + + created_at: int = Field( + default_factory=lambda: int(datetime.now().timestamp()), + description="Unix timestamp (seconds) when resource was created" + ) + updated_at: int = Field( + default_factory=lambda: int(datetime.now().timestamp()), + description="Unix timestamp (seconds) when resource was last updated" + ) + metadata: dict[str, Any] | None = Field( + default=None, + description="Metadata for the resource. Anythning that is not suitable for a regular field." + ) + + @property + def resource_type(self) -> str: + """ + Return the resource type name (class name) for this class. + """ + return self.__class__.__name__ + + def __init_subclass__(cls, **kwargs) -> None: + """ + Initialize subclass by setting up the correct default_factory for the id field. + This method is called when a class inherits from ResourceBase. It ensures + that each subclass gets an id field with a default_factory that generates + IDs in the format "[ClassName]_[uuid4]". + Args: + cls: The subclass being initialized + **kwargs: Additional keyword arguments passed to the subclass + """ + super().__init_subclass__(**kwargs) + # override the default_factory for the id field to use the actual class name + if hasattr(cls, 'model_fields') and 'id' in cls.model_fields: + cls.model_fields['id'].default_factory = lambda: f"{cls.__name__}_{uuid4()}" diff --git a/web_hacker/llms/__init__.py b/web_hacker/llms/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/web_hacker/llms/abstract_llm_vendor_client.py b/web_hacker/llms/abstract_llm_vendor_client.py new file mode 100644 index 0000000..097c8ab --- /dev/null +++ b/web_hacker/llms/abstract_llm_vendor_client.py @@ -0,0 +1,239 @@ +""" +web_hacker/llms/abstract_llm_vendor_client.py + +Abstract base class for LLM vendor clients. +""" + +from abc import ABC, abstractmethod +from collections.abc import Generator +from typing import Any, ClassVar, TypeVar + +from pydantic import BaseModel + +from web_hacker.data_models.chat import LLMChatResponse +from web_hacker.data_models.llms import LLMModel + + +T = TypeVar("T", bound=BaseModel) + + +class AbstractLLMVendorClient(ABC): + """ + Abstract base class defining the interface for LLM vendor clients. + + All vendor-specific clients (OpenAI, Anthropic, etc.) must implement + this interface to ensure consistent behavior across the LLMClient. 
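+
+    Illustrative sketch of the intended polymorphism (assumes the concrete clients
+    defined elsewhere in this package):
+
+        client: AbstractLLMVendorClient = OpenAIClient(model=OpenAIModel.GPT_5_MINI)
+        reply = client.get_text_sync(messages=[{"role": "user", "content": "Hello"}])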
+ """ + + # Class attributes ____________________________________________________________________________________________________ + + DEFAULT_MAX_TOKENS: ClassVar[int] = 4_096 + DEFAULT_TEMPERATURE: ClassVar[float] = 0.7 + DEFAULT_STRUCTURED_TEMPERATURE: ClassVar[float] = 0.0 # deterministic for structured outputs + + + # Magic methods ________________________________________________________________________________________________________ + + def __init__(self, model: LLMModel) -> None: + """ + Initialize the vendor client. + + Args: + model: The LLM model to use. + """ + self.model = model + self._tools: list[dict[str, Any]] = [] + + + # Private methods ______________________________________________________________________________________________________ + + def _resolve_max_tokens(self, max_tokens: int | None) -> int: + """Resolve max_tokens, using default if None.""" + return max_tokens if max_tokens is not None else self.DEFAULT_MAX_TOKENS + + def _resolve_temperature( + self, + temperature: float | None, + structured: bool = False, + ) -> float: + """Resolve temperature, using appropriate default if None.""" + if temperature is not None: + return temperature + return self.DEFAULT_STRUCTURED_TEMPERATURE if structured else self.DEFAULT_TEMPERATURE + + + # Public methods _______________________________________________________________________________________________________ + + ### Tool management + + @abstractmethod + def register_tool( + self, + name: str, + description: str, + parameters: dict[str, Any], + ) -> None: + """ + Register a tool for function calling. + + Args: + name: The name of the tool/function. + description: Description of what the tool does. + parameters: JSON Schema describing the tool's parameters. + """ + pass + + def clear_tools(self) -> None: + """Clear all registered tools.""" + self._tools = [] + + @property + def tools(self) -> list[dict[str, Any]]: + """Return the list of registered tools.""" + return self._tools + + ## Text generation + + @abstractmethod + def get_text_sync( + self, + messages: list[dict[str, str]], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> str: + """ + Get a text response synchronously. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. Defaults to DEFAULT_MAX_TOKENS. + temperature: Sampling temperature (0.0-1.0). Defaults to DEFAULT_TEMPERATURE. + + Returns: + The generated text response. + """ + pass + + @abstractmethod + async def get_text_async( + self, + messages: list[dict[str, str]], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> str: + """ + Get a text response asynchronously. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. Defaults to DEFAULT_MAX_TOKENS. + temperature: Sampling temperature (0.0-1.0). Defaults to DEFAULT_TEMPERATURE. + + Returns: + The generated text response. + """ + pass + + ## Structured responses + + @abstractmethod + def get_structured_response_sync( + self, + messages: list[dict[str, str]], + response_model: type[T], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> T: + """ + Get a structured response as a Pydantic model synchronously. 
+ + Args: + messages: List of message dicts with 'role' and 'content' keys. + response_model: Pydantic model class for the response structure. + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. Defaults to DEFAULT_MAX_TOKENS. + temperature: Sampling temperature. Defaults to DEFAULT_STRUCTURED_TEMPERATURE. + + Returns: + Parsed response as the specified Pydantic model. + """ + pass + + @abstractmethod + async def get_structured_response_async( + self, + messages: list[dict[str, str]], + response_model: type[T], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> T: + """ + Get a structured response as a Pydantic model asynchronously. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + response_model: Pydantic model class for the response structure. + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. Defaults to DEFAULT_MAX_TOKENS. + temperature: Sampling temperature. Defaults to DEFAULT_STRUCTURED_TEMPERATURE. + + Returns: + Parsed response as the specified Pydantic model. + """ + pass + + ## Chat with tools + + @abstractmethod + def chat_sync( + self, + messages: list[dict[str, str]], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> LLMChatResponse: + """ + Chat with the LLM using a message history, with tool calling support. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. Defaults to DEFAULT_MAX_TOKENS. + temperature: Sampling temperature (0.0-1.0). Defaults to DEFAULT_TEMPERATURE. + + Returns: + LLMChatResponse with text content and optional tool call. + """ + pass + + @abstractmethod + def chat_stream_sync( + self, + messages: list[dict[str, str]], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> Generator[str | LLMChatResponse, None, None]: + """ + Chat with streaming, yielding text chunks and final LLMChatResponse. + + Yields text chunks as they arrive, then yields the final LLMChatResponse + with the complete content and any tool call. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. Defaults to DEFAULT_MAX_TOKENS. + temperature: Sampling temperature (0.0-1.0). Defaults to DEFAULT_TEMPERATURE. + + Yields: + str: Text chunks as they arrive. + LLMChatResponse: Final response with complete content and optional tool call. + """ + pass diff --git a/web_hacker/llms/anthropic_client.py b/web_hacker/llms/anthropic_client.py new file mode 100644 index 0000000..cb580db --- /dev/null +++ b/web_hacker/llms/anthropic_client.py @@ -0,0 +1,311 @@ +""" +web_hacker/llms/anthropic_client.py + +Anthropic-specific LLM client implementation. 
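+
+Example (illustrative; assumes ANTHROPIC_API_KEY is set in the environment):
+
+    client = AnthropicClient(AnthropicModel.CLAUDE_SONNET_4_5)
+    reply = client.get_text_sync(messages=[{"role": "user", "content": "Hello"}])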
+""" + +from collections.abc import Generator +from typing import Any, TypeVar + +from anthropic import Anthropic, AsyncAnthropic +from pydantic import BaseModel + +from web_hacker.data_models.chat import LLMChatResponse, LLMToolCall +from web_hacker.data_models.llms import LLMModel +from web_hacker.llms.abstract_llm_vendor_client import AbstractLLMVendorClient +from web_hacker.config import Config +from web_hacker.utils.logger import get_logger + +logger = get_logger(name=__name__) + + +T = TypeVar("T", bound=BaseModel) + + +class AnthropicClient(AbstractLLMVendorClient): + """ + Anthropic-specific implementation of the LLM vendor client. + + Uses the Anthropic Python SDK for message completions, structured outputs + via tool use, and function calling. + """ + + # Magic methods ________________________________________________________________________________________________________ + + def __init__(self, model: LLMModel) -> None: + super().__init__(model) + self._client = Anthropic(api_key=Config.ANTHROPIC_API_KEY) + self._async_client = AsyncAnthropic(api_key=Config.ANTHROPIC_API_KEY) + logger.debug("Initialized AnthropicClient with model: %s", model) + + + # Private methods ______________________________________________________________________________________________________ + + def _extract_text_content(self, content: list[Any]) -> str: + """Extract text from Anthropic content blocks.""" + text_parts = [block.text for block in content if hasattr(block, "text")] + return "".join(text_parts) + + def _build_extraction_tool( + self, + response_model: type[T], + ) -> dict[str, Any]: + """Build the extraction tool for structured responses.""" + tool_schema = response_model.model_json_schema() + return { + "name": "extract_data", + "description": f"Extract data matching the {response_model.__name__} schema.", + "input_schema": tool_schema, + } + + def _append_extraction_instruction( + self, + messages: list[dict[str, str]], + ) -> list[dict[str, str]]: + """Append extraction instruction to the last user message.""" + if not messages: + return messages + + # Copy messages to avoid mutating the original + messages = [dict(m) for m in messages] + + # Find the last user message and append instruction + for i in range(len(messages) - 1, -1, -1): + if messages[i].get("role") == "user": + messages[i]["content"] = ( + f"{messages[i]['content']}\n\nUse the 'extract_data' tool to provide your response " + f"in the exact schema specified." 
+ ) + break + + return messages + + def _extract_tool_result( + self, + content: list[Any], + response_model: type[T], + ) -> T: + """Extract and validate the tool use result from Anthropic response.""" + for block in content: + if hasattr(block, "input") and hasattr(block, "name") and block.name == "extract_data": + return response_model.model_validate(block.input) + raise ValueError("Failed to extract structured response from Anthropic") + + + # Public methods _______________________________________________________________________________________________________ + + ## Tool management + + def register_tool( + self, + name: str, + description: str, + parameters: dict[str, Any], + ) -> None: + """Register a tool in Anthropic's tool format.""" + logger.debug("Registering Anthropic tool: %s", name) + self._tools.append({ + "name": name, + "description": description, + "input_schema": parameters, + }) + + ## Text generation + + def get_text_sync( + self, + messages: list[dict[str, str]], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> str: + """Get a text response synchronously using Anthropic messages API.""" + kwargs: dict[str, Any] = { + "model": self.model.value, + "messages": messages, + "max_tokens": self._resolve_max_tokens(max_tokens), + "temperature": self._resolve_temperature(temperature), + } + if system_prompt: + kwargs["system"] = system_prompt + if self._tools: + kwargs["tools"] = self._tools + + response = self._client.messages.create(**kwargs) + return self._extract_text_content(response.content) + + async def get_text_async( + self, + messages: list[dict[str, str]], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> str: + """Get a text response asynchronously using Anthropic messages API.""" + kwargs: dict[str, Any] = { + "model": self.model.value, + "messages": messages, + "max_tokens": self._resolve_max_tokens(max_tokens), + "temperature": self._resolve_temperature(temperature), + } + if system_prompt: + kwargs["system"] = system_prompt + if self._tools: + kwargs["tools"] = self._tools + + response = await self._async_client.messages.create(**kwargs) + return self._extract_text_content(response.content) + + ## Structured responses + + def get_structured_response_sync( + self, + messages: list[dict[str, str]], + response_model: type[T], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> T: + """Get a structured response using Anthropic's tool_use with forced tool choice.""" + tool = self._build_extraction_tool(response_model) + structured_messages = self._append_extraction_instruction(messages) + + kwargs: dict[str, Any] = { + "model": self.model.value, + "messages": structured_messages, + "max_tokens": self._resolve_max_tokens(max_tokens), + "temperature": self._resolve_temperature(temperature, structured=True), + "tools": [tool], + "tool_choice": {"type": "tool", "name": "extract_data"}, + } + if system_prompt: + kwargs["system"] = system_prompt + + response = self._client.messages.create(**kwargs) + return self._extract_tool_result(response.content, response_model) + + async def get_structured_response_async( + self, + messages: list[dict[str, str]], + response_model: type[T], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> T: + """Get a structured response asynchronously using Anthropic's tool_use.""" + tool = 
self._build_extraction_tool(response_model) + structured_messages = self._append_extraction_instruction(messages) + + kwargs: dict[str, Any] = { + "model": self.model.value, + "messages": structured_messages, + "max_tokens": self._resolve_max_tokens(max_tokens), + "temperature": self._resolve_temperature(temperature, structured=True), + "tools": [tool], + "tool_choice": {"type": "tool", "name": "extract_data"}, + } + if system_prompt: + kwargs["system"] = system_prompt + + response = await self._async_client.messages.create(**kwargs) + return self._extract_tool_result(response.content, response_model) + + ## Chat with tools + + def chat_sync( + self, + messages: list[dict[str, str]], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> LLMChatResponse: + """Chat with Anthropic using message history and tool calling support.""" + kwargs: dict[str, Any] = { + "model": self.model.value, + "messages": messages, + "max_tokens": self._resolve_max_tokens(max_tokens), + "temperature": self._resolve_temperature(temperature), + } + if system_prompt: + kwargs["system"] = system_prompt + if self._tools: + kwargs["tools"] = self._tools + + response = self._client.messages.create(**kwargs) + + # Extract text content + text_content = self._extract_text_content(response.content) + + # Extract tool call if present + tool_call: LLMToolCall | None = None + for block in response.content: + if hasattr(block, "input") and hasattr(block, "name"): + tool_call = LLMToolCall( + tool_name=block.name, + tool_arguments=block.input, + ) + break + + return LLMChatResponse( + content=text_content if text_content else None, + tool_call=tool_call, + ) + + def chat_stream_sync( + self, + messages: list[dict[str, str]], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> Generator[str | LLMChatResponse, None, None]: + """Chat with streaming, yielding text chunks and final LLMChatResponse.""" + kwargs: dict[str, Any] = { + "model": self.model.value, + "messages": messages, + "max_tokens": self._resolve_max_tokens(max_tokens), + "temperature": self._resolve_temperature(temperature), + } + if system_prompt: + kwargs["system"] = system_prompt + if self._tools: + kwargs["tools"] = self._tools + + full_content: list[str] = [] + tool_name: str | None = None + tool_input: dict[str, Any] = {} + + with self._client.messages.stream(**kwargs) as stream: + for event in stream: + # Handle text delta events + if hasattr(event, "type"): + if event.type == "content_block_delta": + if hasattr(event.delta, "text"): + full_content.append(event.delta.text) + yield event.delta.text + elif hasattr(event.delta, "partial_json"): + # Tool input being streamed + pass + elif event.type == "content_block_start": + if hasattr(event.content_block, "name"): + tool_name = event.content_block.name + + # Get final message for tool input + final_message = stream.get_final_message() + for block in final_message.content: + if hasattr(block, "input") and hasattr(block, "name"): + tool_name = block.name + tool_input = block.input + break + + # Build final response + tool_call: LLMToolCall | None = None + if tool_name: + tool_call = LLMToolCall( + tool_name=tool_name, + tool_arguments=tool_input, + ) + + yield LLMChatResponse( + content="".join(full_content) if full_content else None, + tool_call=tool_call, + ) diff --git a/web_hacker/llms/llm_client.py b/web_hacker/llms/llm_client.py new file mode 100644 index 0000000..4d27cde --- /dev/null +++ 
b/web_hacker/llms/llm_client.py @@ -0,0 +1,274 @@ +""" +web_hacker/llms/llm_client.py + +Unified LLM client supporting OpenAI and Anthropic models. +""" + +from collections.abc import Generator +from typing import Any, Callable, TypeVar + +from pydantic import BaseModel + +from web_hacker.data_models.chat import LLMChatResponse +from web_hacker.data_models.llms import ( + LLMModel, + LLMVendor, + OpenAIModel, + get_model_vendor, +) +from web_hacker.llms.tools.tool_utils import extract_description_from_docstring, generate_parameters_schema +from web_hacker.llms.abstract_llm_vendor_client import AbstractLLMVendorClient +from web_hacker.llms.anthropic_client import AnthropicClient +from web_hacker.llms.openai_client import OpenAIClient +from web_hacker.utils.logger import get_logger + +logger = get_logger(name=__name__) + + +T = TypeVar("T", bound=BaseModel) + + +class LLMClient: + """ + Unified LLM client class for interacting with OpenAI and Anthropic APIs. + + This is a facade that delegates to vendor-specific clients (OpenAIClient, + AnthropicClient) based on the selected model. + + Supports: + - Sync and async text generation + - Structured responses using Pydantic models + - Tool/function registration + """ + + # Magic methods ________________________________________________________________________________________________________ + + def __init__( + self, + llm_model: LLMModel = OpenAIModel.GPT_5_MINI, + ) -> None: + self.llm_model = llm_model + self.vendor = get_model_vendor(llm_model) + + # initialize the appropriate vendor client + self._vendor_client: AbstractLLMVendorClient + if self.vendor == LLMVendor.OPENAI: + self._vendor_client = OpenAIClient(model=llm_model) + elif self.vendor == LLMVendor.ANTHROPIC: + self._vendor_client = AnthropicClient(model=llm_model) + else: + raise ValueError(f"Unsupported vendor: {self.vendor}") + + logger.info("Instantiated LLMClient with model: %s (vendor: %s)", llm_model, self.vendor) + + # Public methods _______________________________________________________________________________________________________ + + ## Tools + + def register_tool( + self, + name: str, + description: str, + parameters: dict[str, Any], + ) -> None: + """ + Register a tool for function calling. + + Args: + name: The name of the tool/function. + description: Description of what the tool does. + parameters: JSON Schema describing the tool's parameters. + """ + logger.debug("Registering tool %s (description: %s) with parameters: %s", name, description, parameters) + self._vendor_client.register_tool(name, description, parameters) + + def register_tool_from_function(self, func: Callable[..., Any]) -> None: + """ + Register a tool from a Python function, extracting metadata automatically. + + Extracts: + - name: from func.__name__ + - description: from the docstring (first paragraph) + - parameters: JSON Schema generated from type hints via pydantic + + Args: + func: The function to register as a tool. Must have type hints. 
+ """ + name = func.__name__ + description = extract_description_from_docstring(func.__doc__) + parameters = generate_parameters_schema(func) + self.register_tool(name, description, parameters) + + def clear_tools(self) -> None: + """Clear all registered tools.""" + self._vendor_client.clear_tools() + logger.debug("Cleared all registered tools") + + ## Text generation + + def get_text_sync( + self, + messages: list[dict[str, str]], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> str: + """ + Get a text response synchronously. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. Defaults to DEFAULT_MAX_TOKENS. + temperature: Sampling temperature (0.0-1.0). Defaults to DEFAULT_TEMPERATURE. + + Returns: + The generated text response. + """ + return self._vendor_client.get_text_sync( + messages, + system_prompt, + max_tokens, + temperature, + ) + + async def get_text_async( + self, + messages: list[dict[str, str]], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> str: + """ + Get a text response asynchronously. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. Defaults to DEFAULT_MAX_TOKENS. + temperature: Sampling temperature (0.0-1.0). Defaults to DEFAULT_TEMPERATURE. + + Returns: + The generated text response. + """ + return await self._vendor_client.get_text_async( + messages, + system_prompt, + max_tokens, + temperature, + ) + + ## Structured responses + + def get_structured_response_sync( + self, + messages: list[dict[str, str]], + response_model: type[T], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> T: + """ + Get a structured response as a Pydantic model synchronously. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + response_model: Pydantic model class for the response structure. + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. Defaults to DEFAULT_MAX_TOKENS. + temperature: Sampling temperature. Defaults to DEFAULT_STRUCTURED_TEMPERATURE. + + Returns: + Parsed response as the specified Pydantic model. + """ + return self._vendor_client.get_structured_response_sync( + messages, + response_model, + system_prompt, + max_tokens, + temperature, + ) + + async def get_structured_response_async( + self, + messages: list[dict[str, str]], + response_model: type[T], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> T: + """ + Get a structured response as a Pydantic model asynchronously. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + response_model: Pydantic model class for the response structure. + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. Defaults to DEFAULT_MAX_TOKENS. + temperature: Sampling temperature. Defaults to DEFAULT_STRUCTURED_TEMPERATURE. + + Returns: + Parsed response as the specified Pydantic model. 
+ """ + return await self._vendor_client.get_structured_response_async( + messages, + response_model, + system_prompt, + max_tokens, + temperature, + ) + + ## Chat with tools + + def chat_sync( + self, + messages: list[dict[str, str]], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> LLMChatResponse: + """ + Chat with the LLM using a message history, with tool calling support. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. Defaults to DEFAULT_MAX_TOKENS. + temperature: Sampling temperature (0.0-1.0). Defaults to DEFAULT_TEMPERATURE. + + Returns: + LLMChatResponse with text content and optional tool call. + """ + return self._vendor_client.chat_sync( + messages, + system_prompt, + max_tokens, + temperature, + ) + + def chat_stream_sync( + self, + messages: list[dict[str, str]], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> Generator[str | LLMChatResponse, None, None]: + """ + Chat with streaming, yielding text chunks and final LLMChatResponse. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + system_prompt: Optional system prompt for context. + max_tokens: Maximum tokens in the response. Defaults to DEFAULT_MAX_TOKENS. + temperature: Sampling temperature (0.0-1.0). Defaults to DEFAULT_TEMPERATURE. + + Yields: + str: Text chunks as they arrive. + LLMChatResponse: Final response with complete content and optional tool call. + """ + yield from self._vendor_client.chat_stream_sync( + messages, + system_prompt, + max_tokens, + temperature, + ) diff --git a/web_hacker/llms/openai_client.py b/web_hacker/llms/openai_client.py new file mode 100644 index 0000000..e4e1f14 --- /dev/null +++ b/web_hacker/llms/openai_client.py @@ -0,0 +1,273 @@ +""" +web_hacker/llms/openai_client.py + +OpenAI-specific LLM client implementation. +""" + +import json +from collections.abc import Generator +from typing import Any, TypeVar + +from openai import OpenAI, AsyncOpenAI +from pydantic import BaseModel + +from web_hacker.data_models.chat import LLMChatResponse, LLMToolCall +from web_hacker.data_models.llms import LLMModel +from web_hacker.llms.abstract_llm_vendor_client import AbstractLLMVendorClient +from web_hacker.config import Config +from web_hacker.utils.logger import get_logger + +logger = get_logger(name=__name__) + + +T = TypeVar("T", bound=BaseModel) + + +class OpenAIClient(AbstractLLMVendorClient): + """ + OpenAI-specific implementation of the LLM vendor client. + + Uses the OpenAI Python SDK for chat completions, structured outputs, + and function calling. + """ + + # Magic methods ________________________________________________________________________________________________________ + + def __init__(self, model: LLMModel) -> None: + super().__init__(model) + self._client = OpenAI(api_key=Config.OPENAI_API_KEY) + self._async_client = AsyncOpenAI(api_key=Config.OPENAI_API_KEY) + logger.debug("Initialized OpenAIClient with model: %s", model) + + # Private methods ______________________________________________________________________________________________________ + + def _prepend_system_prompt( + self, + messages: list[dict[str, str]], + system_prompt: str | None, + ) -> list[dict[str, str]]: + """ + Prepend system prompt to messages if provided. + + Args: + messages: List of message dicts with 'role' and 'content' keys. 
+ system_prompt: Optional system prompt for context. + + Returns: + Messages list with system prompt prepended if provided. + """ + if system_prompt: + return [{"role": "system", "content": system_prompt}] + messages + return messages + + # Public methods _______________________________________________________________________________________________________ + + ## Tool management + + def register_tool( + self, + name: str, + description: str, + parameters: dict[str, Any], + ) -> None: + """Register a tool in OpenAI's function calling format.""" + logger.debug("Registering OpenAI tool: %s", name) + self._tools.append({ + "type": "function", + "function": { + "name": name, + "description": description, + "parameters": parameters, + } + }) + + ## Text generation + + def get_text_sync( + self, + messages: list[dict[str, str]], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> str: + """Get a text response synchronously using OpenAI chat completions.""" + all_messages = self._prepend_system_prompt(messages, system_prompt) + + kwargs: dict[str, Any] = { + "model": self.model.value, + "messages": all_messages, + "max_completion_tokens": self._resolve_max_tokens(max_tokens), + # Note: GPT-5 models only support temperature=1 (default), so we omit it + } + if self._tools: + kwargs["tools"] = self._tools + + response = self._client.chat.completions.create(**kwargs) + return response.choices[0].message.content or "" + + async def get_text_async( + self, + messages: list[dict[str, str]], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> str: + """Get a text response asynchronously using OpenAI chat completions.""" + all_messages = self._prepend_system_prompt(messages, system_prompt) + + kwargs: dict[str, Any] = { + "model": self.model.value, + "messages": all_messages, + "max_completion_tokens": self._resolve_max_tokens(max_tokens), + # Note: GPT-5 models only support temperature=1 (default), so we omit it + } + if self._tools: + kwargs["tools"] = self._tools + + response = await self._async_client.chat.completions.create(**kwargs) + return response.choices[0].message.content or "" + + ## Structured responses + + def get_structured_response_sync( + self, + messages: list[dict[str, str]], + response_model: type[T], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> T: + """Get a structured response using OpenAI's native response_format parsing.""" + all_messages = self._prepend_system_prompt(messages, system_prompt) + + response = self._client.beta.chat.completions.parse( + model=self.model.value, + messages=all_messages, + response_format=response_model, + max_completion_tokens=self._resolve_max_tokens(max_tokens), + # Note: GPT-5 models only support temperature=1 (default), so we omit it + ) + parsed = response.choices[0].message.parsed + if parsed is None: + raise ValueError("Failed to parse structured response from OpenAI") + return parsed + + async def get_structured_response_async( + self, + messages: list[dict[str, str]], + response_model: type[T], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> T: + """Get a structured response asynchronously using OpenAI's native response_format parsing.""" + all_messages = self._prepend_system_prompt(messages, system_prompt) + + response = await self._async_client.beta.chat.completions.parse( + 
model=self.model.value, + messages=all_messages, + response_format=response_model, + max_completion_tokens=self._resolve_max_tokens(max_tokens), + # Note: GPT-5 models only support temperature=1 (default), so we omit it + ) + parsed = response.choices[0].message.parsed + if parsed is None: + raise ValueError("Failed to parse structured response from OpenAI") + return parsed + + ## Chat with tools + + def chat_sync( + self, + messages: list[dict[str, str]], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> LLMChatResponse: + """Chat with OpenAI using message history and tool calling support.""" + all_messages = self._prepend_system_prompt(messages, system_prompt) + + kwargs: dict[str, Any] = { + "model": self.model.value, + "messages": all_messages, + "max_completion_tokens": self._resolve_max_tokens(max_tokens), + # Note: GPT-5 models only support temperature=1 (default), so we omit it + } + if self._tools: + kwargs["tools"] = self._tools + + response = self._client.chat.completions.create(**kwargs) + message = response.choices[0].message + + # Extract tool call if present + tool_call: LLMToolCall | None = None + if message.tool_calls and len(message.tool_calls) > 0: + tc = message.tool_calls[0] + tool_call = LLMToolCall( + tool_name=tc.function.name, + tool_arguments=json.loads(tc.function.arguments), + ) + + return LLMChatResponse( + content=message.content, + tool_call=tool_call, + ) + + def chat_stream_sync( + self, + messages: list[dict[str, str]], + system_prompt: str | None = None, + max_tokens: int | None = None, + temperature: float | None = None, + ) -> Generator[str | LLMChatResponse, None, None]: + """Chat with streaming, yielding text chunks and final LLMChatResponse.""" + all_messages = self._prepend_system_prompt(messages, system_prompt) + + kwargs: dict[str, Any] = { + "model": self.model.value, + "messages": all_messages, + "max_completion_tokens": self._resolve_max_tokens(max_tokens), + "stream": True, + # Note: GPT-5 models only support temperature=1 (default), so we omit it + } + if self._tools: + kwargs["tools"] = self._tools + + stream = self._client.chat.completions.create(**kwargs) + + # Accumulate content and tool call data + full_content: list[str] = [] + tool_call_name: str | None = None + tool_call_args: list[str] = [] + + for chunk in stream: + delta = chunk.choices[0].delta if chunk.choices else None + if delta is None: + continue + + # Handle text content + if delta.content: + full_content.append(delta.content) + yield delta.content + + # Handle tool calls (streamed in chunks) + if delta.tool_calls: + for tc in delta.tool_calls: + if tc.function: + if tc.function.name: + tool_call_name = tc.function.name + if tc.function.arguments: + tool_call_args.append(tc.function.arguments) + + # Build final response + tool_call: LLMToolCall | None = None + if tool_call_name: + tool_call = LLMToolCall( + tool_name=tool_call_name, + tool_arguments=json.loads("".join(tool_call_args)) if tool_call_args else {}, + ) + + yield LLMChatResponse( + content="".join(full_content) if full_content else None, + tool_call=tool_call, + ) diff --git a/web_hacker/llms/tools/__init__.py b/web_hacker/llms/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/web_hacker/llms/tools/guide_agent_tools.py b/web_hacker/llms/tools/guide_agent_tools.py new file mode 100644 index 0000000..6cba95c --- /dev/null +++ b/web_hacker/llms/tools/guide_agent_tools.py @@ -0,0 +1,39 @@ +""" 
+web_hacker/llms/tools/guide_agent_tools.py + +Tool functions for the guide agent. +""" + +from typing import Any + + +def start_routine_discovery_job_creation( + task_description: str, + expected_output_description: str, + input_parameters: list[dict[str, str]] | None = None, + filters_or_constraints: list[str] | None = None, + target_website: str | None = None, +) -> dict[str, Any]: + """ + Initiates the routine discovery process. + + Call this when you have gathered enough information about: + 1) What task the user wants to automate + 2) What data/output they expect + 3) What input parameters the routine should accept + 4) Any filters or constraints + + This tool requests user confirmation before executing. + + Args: + task_description: Description of the task/routine the user wants to create + expected_output_description: Description of what data the routine should return + input_parameters: List of input parameters with 'name' and 'description' keys + filters_or_constraints: Any filters or constraints the user mentioned + target_website: Target website/URL if mentioned by user + + Returns: + Result dict to be passed to routine discovery agent + """ + # TODO: implement the actual handoff logic + raise NotImplementedError("start_routine_discovery_job_creation not yet implemented") diff --git a/web_hacker/llms/tools/tool_utils.py b/web_hacker/llms/tools/tool_utils.py new file mode 100644 index 0000000..fcc4fe3 --- /dev/null +++ b/web_hacker/llms/tools/tool_utils.py @@ -0,0 +1,74 @@ +""" +web_hacker/llms/tools/tool_utils.py + +Utilities for converting Python functions to LLM tool definitions. +""" + +import inspect +from typing import Any, Callable, get_type_hints + +from pydantic import TypeAdapter + + +def extract_description_from_docstring(docstring: str | None) -> str: + """ + Extract the first paragraph from a docstring as the description. + + Args: + docstring: The function's docstring (func.__doc__) + + Returns: + The first paragraph of the docstring, or empty string if none. + """ + if not docstring: + return "" + + # split on double newlines to get first paragraph + paragraphs = docstring.strip().split("\n\n") + if not paragraphs: + return "" + + # clean up the first paragraph (remove leading/trailing whitespace from each line) + first_para = paragraphs[0] + lines = [line.strip() for line in first_para.split("\n")] + return " ".join(lines) + + +def generate_parameters_schema(func: Callable[..., Any]) -> dict[str, Any]: + """ + Generate JSON Schema for function parameters using pydantic. + + Args: + func: The function to generate schema for. Must have type hints. + + Returns: + JSON Schema dict with 'type', 'properties', and 'required' keys. 
+ """ + sig = inspect.signature(obj=func) + hints = get_type_hints(obj=func) + + properties: dict[str, Any] = {} + required: list[str] = [] + + for param_name, param in sig.parameters.items(): + if param_name in ("self", "cls"): + continue + + param_type = hints.get(param_name, Any) + # use pydantic TypeAdapter to generate schema for this type + schema = TypeAdapter(param_type).json_schema() + + # remove pydantic metadata that's not needed for tool schemas + schema.pop("title", None) + + properties[param_name] = schema + + # parameter is required if it has no default value + if param.default is inspect.Parameter.empty: + required.append(param_name) + + return { + "type": "object", + "properties": properties, + "required": required, + } diff --git a/web_hacker/routine_discovery/context_manager.py b/web_hacker/routine_discovery/context_manager.py index 73f90ac..2b16b4f 100644 --- a/web_hacker/routine_discovery/context_manager.py +++ b/web_hacker/routine_discovery/context_manager.py @@ -1,10 +1,17 @@ -from pydantic import BaseModel, field_validator, Field, ConfigDict -from openai import OpenAI -from abc import ABC, abstractmethod -import os +""" +web_hacker/routine_discovery/context_manager.py + +Context management abstractions and utilities for routine discovery. +""" + import json -import time +import os import shutil +import time +from abc import ABC, abstractmethod + +from openai import OpenAI +from pydantic import BaseModel, ConfigDict, Field, field_validator from web_hacker.utils.data_utils import get_text_from_html diff --git a/web_hacker/scripts/run_guide_agent.py b/web_hacker/scripts/run_guide_agent.py new file mode 100644 index 0000000..0d73f11 --- /dev/null +++ b/web_hacker/scripts/run_guide_agent.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +""" +web_hacker/scripts/run_guide_agent.py + +Terminal-based chat interface for the Guide Agent. + +Usage: + python -m web_hacker.scripts.run_guide_agent +""" + +import sys + +from web_hacker.agents.guide_agent.guide_agent import GuideAgent +from web_hacker.data_models.chat import Chat, ChatThread, EmittedChatMessage, ChatMessageType +from web_hacker.data_models.llms import OpenAIModel +from web_hacker.config import Config +from web_hacker.utils.exceptions import ApiKeyNotFoundError +from web_hacker.utils.logger import get_logger + +logger = get_logger(__name__) + + +class TerminalGuideChat: + """Terminal interface for Guide Agent.""" + + def __init__(self) -> None: + self.agent = GuideAgent( + emit_message_callable=self._handle_emitted_message, + persist_chat_callable=self._persist_chat, + persist_chat_thread_callable=self._persist_chat_thread, + llm_model=OpenAIModel.GPT_5_MINI, + ) + + def _persist_chat(self, chat: Chat) -> None: + """Persist a Chat object. In the future, this will POST to DynamoDB.""" + logger.debug("Would persist Chat: %s", chat.id) + + def _persist_chat_thread(self, thread: ChatThread) -> None: + """Persist a ChatThread object. 
In the future, this will POST/PATCH to DynamoDB.""" + logger.debug("Would persist ChatThread: %s", thread.id) + + def _handle_emitted_message(self, message: EmittedChatMessage) -> None: + """Handle messages emitted by the agent.""" + if message.type == ChatMessageType.CHAT_RESPONSE: + if message.content: + print(f"\nAssistant: {message.content}") + + elif message.type == ChatMessageType.TOOL_INVOCATION_REQUEST: + print("\n" + "=" * 60) + print("TOOL INVOCATION REQUEST") + if message.tool_invocation: + print(f"Tool: {message.tool_invocation.tool_name}") + print(f"Arguments: {message.tool_invocation.tool_arguments}") + print("=" * 60) + print("Type 'yes' to confirm or 'no' to deny:") + + elif message.type == ChatMessageType.TOOL_INVOCATION_RESULT: + print("\n" + "-" * 60) + if message.tool_result: + print(f"Tool executed successfully!") + print(f"Result: {message.tool_result}") + elif message.error: + print(f"Tool execution failed: {message.error}") + elif message.content: + print(message.content) + print("-" * 60) + + elif message.type == ChatMessageType.ERROR: + print(f"\nError: {message.error}") + + def run(self) -> None: + """Run the terminal chat loop.""" + print("=" * 60) + print("Welcome to the Web Hacker Guide Agent") + print("I'll help you define your web automation routine.") + print("Type 'quit' or 'exit' to end the conversation.") + print("=" * 60) + print() + + while True: + try: + # Get user input + user_input = input("You: ").strip() + + if not user_input: + continue + + if user_input.lower() in ("quit", "exit"): + print("Goodbye!") + sys.exit(0) + + # Handle tool confirmation + if self.agent.has_pending_tool_invocation: + pending = self.agent.get_thread().pending_tool_invocation + if pending: + if user_input.lower() in ("yes", "y", "confirm"): + self.agent.confirm_tool_invocation(pending.invocation_id) + elif user_input.lower() in ("no", "n", "deny"): + self.agent.deny_tool_invocation(pending.invocation_id, "User denied") + else: + print("Please type 'yes' to confirm or 'no' to deny.") + continue + + # Process message + self.agent.process_user_message(user_input) + + except KeyboardInterrupt: + print("\nGoodbye!") + sys.exit(0) + except Exception as e: + logger.exception("Error in chat loop: %s", e) + print(f"\nAn error occurred: {e}") + + +def main() -> None: + """Entry point for the terminal chat.""" + # Ensure OpenAI API key is set + if Config.OPENAI_API_KEY is None: + logger.error("OPENAI_API_KEY is not set") + raise ApiKeyNotFoundError("OPENAI_API_KEY is not set") + + chat = TerminalGuideChat() + chat.run() + + +if __name__ == "__main__": + main() diff --git a/web_hacker/utils/exceptions.py b/web_hacker/utils/exceptions.py index ed299f6..8b82343 100644 --- a/web_hacker/utils/exceptions.py +++ b/web_hacker/utils/exceptions.py @@ -81,3 +81,9 @@ class WebHackerError(Exception): """ Base exception for all Web Hacker errors. """ + + +class UnknownToolError(Exception): + """ + Raised when attempting to execute a tool that does not exist. + """
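
Reviewer note (not part of the diff): a minimal usage sketch of the LLMClient facade added in web_hacker/llms/llm_client.py, assuming OPENAI_API_KEY is configured via Config and that OpenAIModel.GPT_5_MINI exists in web_hacker.data_models.llms as shown above. The search_flights tool and RouteSummary model below are hypothetical, used only to illustrate the interface; actually executing the calls requires a valid API key.

    from pydantic import BaseModel

    from web_hacker.data_models.llms import OpenAIModel
    from web_hacker.llms.llm_client import LLMClient


    class RouteSummary(BaseModel):
        """Hypothetical response model, only to illustrate structured output."""
        origin: str
        destination: str
        summary: str


    def search_flights(origin: str, destination: str, max_results: int = 5) -> dict:
        """Search flights between two airports (hypothetical example tool)."""
        raise NotImplementedError


    client = LLMClient(llm_model=OpenAIModel.GPT_5_MINI)

    # Tool name, description, and JSON Schema are derived from the function's
    # docstring and type hints (see tool_utils.generate_parameters_schema).
    client.register_tool_from_function(search_flights)

    # Plain text generation.
    text = client.get_text_sync(
        messages=[{"role": "user", "content": "Summarize the JFK to LAX route."}],
        system_prompt="You are a concise travel assistant.",
    )

    # Structured output parsed into the Pydantic model.
    route = client.get_structured_response_sync(
        messages=[{"role": "user", "content": "Describe the JFK to LAX route."}],
        response_model=RouteSummary,
    )

    # Chat with tool-calling support; the facade returns text and/or one tool call.
    response = client.chat_sync(
        messages=[{"role": "user", "content": "Find flights from JFK to LAX."}],
    )
    if response.tool_call:
        print(response.tool_call.tool_name, response.tool_call.tool_arguments)

    # Streaming: yields text chunks as they arrive, then a final LLMChatResponse.
    for item in client.chat_stream_sync(
        messages=[{"role": "user", "content": "Find flights from JFK to LAX."}],
    ):
        if isinstance(item, str):
            print(item, end="", flush=True)

Vendor differences (the GPT-5 temperature restriction, Anthropic's forced tool_choice for structured output) stay inside OpenAIClient and AnthropicClient, so callers only interact with the interface shown here.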