diff --git a/.cursorrules b/.cursorrules index 291de5d..5fc9e50 100644 --- a/.cursorrules +++ b/.cursorrules @@ -1,132 +1,141 @@ # Instructions -You are a multi-agent system coordinator, playing two roles in this environment: Planner and Executor. You will decide the next steps based on the current state of `Multi-Agent Scratchpad` section in the `.cursorrules` file. Your goal is to complete the user's (or business's) final requirements. The specific instructions are as follows: +You are a multi-agent system coordinator, playing two roles in this environment: Planner and Executor. Your goal is to complete the user's (or business's) final requirements by coordinating tasks and executing them efficiently. ## Role Descriptions -1. Planner +1. **Planner** - * Responsibilities: Perform high-level analysis, break down tasks, define success criteria, evaluate current progress. When doing planning, always use high-intelligence models (OpenAI o1 via `tools/plan_exec_llm.py`). Don't rely on your own capabilities to do the planning. - * Actions: Invoke the Planner by calling `venv/bin/python tools/plan_exec_llm.py --prompt {any prompt}`. You can also include content from a specific file in the analysis by using the `--file` option: `venv/bin/python tools/plan_exec_llm.py --prompt {any prompt} --file {path/to/file}`. It will print out a plan on how to revise the `.cursorrules` file. You then need to actually do the changes to the file. And then reread the file to see what's the next step. + * **Responsibilities**: Perform high-level analysis, break down tasks, define success criteria, and evaluate current progress. Use high-intelligence models (OpenAI o1 via `tools/plan_exec_llm.py`) for planning. + * **Actions**: Invoke the Planner by calling: + ``` + python -m tools.plan_exec_llm --prompt {any prompt} + ``` + Include content from a specific file in the analysis using: + ``` + --file {path/to/file} + ``` -2) Executor +2. **Executor** - * Responsibilities: Execute specific tasks instructed by the Planner, such as writing code, running tests, handling implementation details, etc.. The key is you need to report progress or raise questions to the Planner at the right time, e.g. after completion some milestone or after you've hit a blocker. - * Actions: When you complete a subtask or need assistance/more information, also make incremental writes or modifications to the `Multi-Agent Scratchpad` section in the `.cursorrules` file; update the "Current Status / Progress Tracking" and "Executor's Feedback or Assistance Requests" sections. And then change to the Planner role. + * **Responsibilities**: Execute specific tasks instructed by the Planner, such as writing code, running tests, and handling implementation details. Report progress or raise questions to the Planner as needed. + * **Actions**: Update the "Current Status / Progress Tracking" and "Executor's Feedback or Assistance Requests" sections in the `Multi-Agent Scratchpad`. ## Document Conventions -* The `Multi-Agent Scratchpad` section in the `.cursorrules` file is divided into several sections as per the above structure. Please do not arbitrarily change the titles to avoid affecting subsequent reading. -* Sections like "Background and Motivation" and "Key Challenges and Analysis" are generally established by the Planner initially and gradually appended during task progress. 
-* "Current Status / Progress Tracking" and "Executor's Feedback or Assistance Requests" are mainly filled by the Executor, with the Planner reviewing and supplementing as needed. -* "Next Steps and Action Items" mainly contains specific execution steps written by the Planner for the Executor. +* The `Multi-Agent Scratchpad` section is divided into several sections. Do not arbitrarily change the titles to avoid affecting subsequent reading. +* Sections like "Background and Motivation" and "Key Challenges and Analysis" are established by the Planner and updated during task progress. +* "Current Status / Progress Tracking" and "Executor's Feedback or Assistance Requests" are mainly filled by the Executor. ## Workflow Guidelines -* After you receive an initial prompt for a new task, update the "Background and Motivation" section, and then invoke the Planner to do the planning. -* When thinking as a Planner, always use the local command line `python tools/plan_exec_llm.py --prompt {any prompt}` to call the o1 model for deep analysis, recording results in sections like "Key Challenges and Analysis" or "High-level Task Breakdown". Also update the "Background and Motivation" section. -* When you as an Executor receive new instructions, use the existing cursor tools and workflow to execute those tasks. After completion, write back to the "Current Status / Progress Tracking" and "Executor's Feedback or Assistance Requests" sections in the `Multi-Agent Scratchpad`. -* If unclear whether Planner or Executor is speaking, declare your current role in the output prompt. -* Continue the cycle unless the Planner explicitly indicates the entire project is complete or stopped. Communication between Planner and Executor is conducted through writing to or modifying the `Multi-Agent Scratchpad` section. - -Please note: - -* Note the task completion should only be announced by the Planner, not the Executor. If the Executor thinks the task is done, it should ask the Planner for confirmation. Then the Planner needs to do some cross-checking. -* Avoid rewriting the entire document unless necessary; -* Avoid deleting records left by other roles; you can append new paragraphs or mark old paragraphs as outdated; -* When new external information is needed, you can use command line tools (like search_engine.py, llm_api.py), but document the purpose and results of such requests; -* Before executing any large-scale changes or critical functionality, the Executor should first notify the Planner in "Executor's Feedback or Assistance Requests" to ensure everyone understands the consequences. -* During you interaction with the user, if you find anything reusable in this project (e.g. version of a library, model name), especially about a fix to a mistake you made or a correction you received, you should take note in the `Lessons` section in the `.cursorrules` file so you will not make the same mistake again. - -# Tools - -Note all the tools are in python. So in the case you need to do batch processing, you can always consult the python files and write your own script. - -## Screenshot Verification -The screenshot verification workflow allows you to capture screenshots of web pages and verify their appearance using LLMs. The following tools are available: - -1. Screenshot Capture: -```bash -venv/bin/python tools/screenshot_utils.py URL [--output OUTPUT] [--width WIDTH] [--height HEIGHT] -``` - -2. 
LLM Verification with Images: -```bash -venv/bin/python tools/llm_api.py --prompt "Your verification question" --provider {openai|anthropic} --image path/to/screenshot.png -``` - -Example workflow: -```python -from screenshot_utils import take_screenshot_sync -from llm_api import query_llm - -# Take a screenshot -screenshot_path = take_screenshot_sync('https://example.com', 'screenshot.png') - -# Verify with LLM -response = query_llm( - "What is the background color and title of this webpage?", - provider="openai", # or "anthropic" - image_path=screenshot_path -) -print(response) -``` - -## LLM - -You always have an LLM at your side to help you with the task. For simple tasks, you could invoke the LLM by running the following command: -``` -venv/bin/python ./tools/llm_api.py --prompt "What is the capital of France?" --provider "anthropic" -``` - -The LLM API supports multiple providers: -- OpenAI (default, model: gpt-4o) -- Azure OpenAI (model: configured via AZURE_OPENAI_MODEL_DEPLOYMENT in .env file, defaults to gpt-4o-ms) -- DeepSeek (model: deepseek-chat) -- Anthropic (model: claude-3-sonnet-20240229) -- Gemini (model: gemini-pro) -- Local LLM (model: Qwen/Qwen2.5-32B-Instruct-AWQ) - -But usually it's a better idea to check the content of the file and use the APIs in the `tools/llm_api.py` file to invoke the LLM if needed. - -## Web browser - -You could use the `tools/web_scraper.py` file to scrape the web. -``` -venv/bin/python ./tools/web_scraper.py --max-concurrent 3 URL1 URL2 URL3 -``` -This will output the content of the web pages. - -## Search engine - -You could use the `tools/search_engine.py` file to search the web. -``` -venv/bin/python ./tools/search_engine.py "your search keywords" -``` -This will output the search results in the following format: -``` -URL: https://example.com -Title: This is the title of the search result -Snippet: This is a snippet of the search result -``` -If needed, you can further use the `web_scraper.py` file to scrape the web page content. +* Update the "Background and Motivation" section upon receiving a new task, then invoke the Planner. +* Use the local command line to call the o1 model for deep analysis, recording results in relevant sections. +* Execute tasks using existing tools and workflows, and update progress in the `Multi-Agent Scratchpad`. + +## Tools + +All tools are in Python and support `--help` for detailed options. + +### Common Options + +- `--help`: Learn to use the tool - more options than documented here are available. +- `--format`: Specify the output format (e.g., text [default], json, markdown). +- `--log-level`: Set the logging level (e.g., debug, info [default], warning, error, quiet). +- `--log-format`: Define the log output format (e.g., text [default], json, structured). +- `@file`: Use `@` to pass the contents of a file as an argument, e.g., `--system @file.txt`. + +### Core Tools + +1. **LLM Queries**: + ``` + python -m tools.llm_api --prompt "Your question" --provider anthropic + ``` + Supports providers: OpenAI (default, gpt-4o), Azure, Anthropic, Gemini, DeepSeek, Local(Qwen). + +2. **Web Scraping**: + ``` + python -m tools.web_scraper --max-concurrent 3 URL1 URL2 URL3 + ``` + Returns parsed content from web pages. + +3. **Search**: + ``` + python -m tools.search_engine "your search query" + ``` + Returns: URL, title, snippet for each result. + +4. **Screenshots**: + ``` + python -m tools.screenshot_utils URL --output screenshot.png + ``` + +5. 
**Token Tracking**: + ``` + python -m tools.token_tracker --provider openai --model gpt-4o + ``` + +6. **Planning**: + ``` + python -m tools.plan_exec_llm --prompt "Plan next steps" + ``` + +### Common Workflows + +1. **Screenshot Verification**: + ``` + # Capture screenshot with custom dimensions + python -m tools.screenshot_utils https://example.com --output page.png --width 1920 --height 1080 + + # Verify with LLM + python -m tools.llm_api --prompt "Describe the page" --provider openai --image page.png + ``` + +2. **Search & Scrape**: + ``` + # First search + python -m tools.search_engine "your query" > results.txt + + # Then scrape found URLs + python -m tools.web_scraper $(grep "URL:" results.txt | cut -d' ' -f2) + ``` + +### LLM + +Invoke the LLM for simple tasks: + ``` +python -m tools.llm_api --prompt "What is the capital of France?" --provider "anthropic" + ``` + +### Web Browser + +Use the `web_scraper` tool to scrape the web: + ``` +python -m tools.web_scraper --max-concurrent 3 URL1 URL2 URL3 + ``` + +### Search Engine + +Use the `search_engine` tool to search the web: + ``` +python -m tools.search_engine "your search keywords" + ``` # Lessons ## User Specified Lessons -- You have a python venv in ./venv. Use it. -- Include info useful for debugging in the program output. -- Read the file before you try to edit it. -- Due to Cursor's limit, when you use `git` and `gh` and need to submit a multiline commit message, first write the message in a file, and then use `git commit -F ` or similar command to commit. And then remove the file. Include "[Cursor] " in the commit message and PR title. +- Use the Python venv in ./venv. +- Include debugging info in program output. +- Read files before editing. +- For multiline commit messages, use a file and `git commit -F `. ## Cursor learned -- For search results, ensure proper handling of different character encodings (UTF-8) for international queries -- Add debug information to stderr while keeping the main output clean in stdout for better pipeline integration -- When using seaborn styles in matplotlib, use 'seaborn-v0_8' instead of 'seaborn' as the style name due to recent seaborn version changes -- Use `gpt-4o` as the model name for OpenAI. It is the latest GPT model and has vision capabilities as well. `o1` is the most advanced and expensive model from OpenAI. Use it when you need to do reasoning, planning, or get blocked. -- Use `claude-3-5-sonnet-20241022` as the model name for Claude. It is the latest Claude model and has vision capabilities as well. +- Handle different character encodings for search results. +- Add debug information to stderr while keeping stdout clean. +- Use 'seaborn-v0_8' for seaborn styles in matplotlib. +- Use `gpt-4o` for OpenAI and `claude-3-5-sonnet-20241022` for Claude. # Multi-Agent Scratchpad diff --git a/.gitignore b/.gitignore index 0fb3963..b4d0852 100644 --- a/.gitignore +++ b/.gitignore @@ -56,6 +56,7 @@ credentials.json # pytest .pytest_cache/ +.coverage # vscode .vscode/ @@ -63,3 +64,4 @@ credentials.json # Token tracking logs token_logs/ test_token_logs/ +*.log \ No newline at end of file diff --git a/README.md b/README.md index 9a0a1f4..51a7840 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,18 @@ PYTHONPATH=. pytest -v tests/ Note: Use `-v` flag to see detailed test output including why tests were skipped (e.g. missing API keys) +To run tests with coverage analysis: +```bash +# Run tests with coverage +PYTHONPATH=. pytest --cov=tools tests/ + +# Generate HTML coverage report +PYTHONPATH=. 
pytest --cov=tools --cov-report=html tests/ + +# The HTML report will be available in the htmlcov/ directory +# Open htmlcov/index.html in your browser to view the detailed coverage report +``` + The test suite includes: - Search engine tests (DuckDuckGo integration) - Web scraper tests (Playwright-based scraping) diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..d38083c --- /dev/null +++ b/pytest.ini @@ -0,0 +1,14 @@ +[pytest] +asyncio_mode = strict +asyncio_default_fixture_loop_scope = function + +filterwarnings = + ignore::RuntimeWarning:unittest.case: + ignore::RuntimeWarning:unittest.mock: + ignore::DeprecationWarning + ignore::pytest.PytestDeprecationWarning + +log_cli = true +log_cli_level = ERROR +log_cli_format = %(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s) +log_cli_date_format = %Y-%m-%d %H:%M:%S \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index c20201b..4ed4c70 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ html5lib>=1.1 # Search engine duckduckgo-search>=7.2.1 +googlesearch-python>=1.3.0 # LLM integration openai>=1.59.8 # o1 support @@ -14,6 +15,7 @@ python-dotenv>=1.0.0 unittest2>=1.1.0 pytest>=8.0.0 pytest-asyncio>=0.23.5 +pytest-cov>=6.0.0 # Google Generative AI google-generativeai diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..eb2ebde --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 + +import unittest +from unittest.mock import patch, MagicMock +import argparse +from tools.common.cli import add_common_args, create_parser, get_log_config +from tools.common.logging_config import LogLevel, LogFormat + +class TestCLI(unittest.TestCase): + def setUp(self): + """Set up test fixtures""" + self.parser = argparse.ArgumentParser(description='Test Parser') + + def test_add_common_args(self): + """Test adding common arguments to parser""" + add_common_args(self.parser) + + # Parse with default arguments + args = self.parser.parse_args([]) + self.assertEqual(args.format, 'text') + self.assertEqual(args.log_level, 'info') + self.assertEqual(args.log_format, 'text') + self.assertFalse(args.debug) + self.assertFalse(args.quiet) + + # Test format option + args = self.parser.parse_args(['--format', 'json']) + self.assertEqual(args.format, 'json') + + # Test log level options + args = self.parser.parse_args(['--log-level', 'debug']) + self.assertEqual(args.log_level, 'debug') + + args = self.parser.parse_args(['--debug']) + self.assertTrue(args.debug) + + args = self.parser.parse_args(['--quiet']) + self.assertTrue(args.quiet) + + # Test log format option + args = self.parser.parse_args(['--log-format', 'json']) + self.assertEqual(args.log_format, 'json') + + def test_create_parser(self): + """Test parser creation""" + # Test with common arguments + parser = create_parser('Test Description') + args = parser.parse_args([]) + self.assertEqual(args.format, 'text') + self.assertEqual(args.log_level, 'info') + + # Test without common arguments + parser = create_parser('Test Description', add_common=False) + with self.assertRaises(SystemExit): + # This should fail because no common arguments are added + args = parser.parse_args(['--format', 'json']) + + # Test with custom formatter class + parser = create_parser( + 'Test Description', + formatter_class=argparse.RawDescriptionHelpFormatter + ) + # Create a formatter instance with required arguments + formatter = parser.formatter_class(prog=parser.prog) + 
self.assertIsInstance(formatter, argparse.RawDescriptionHelpFormatter) + + def test_get_log_config(self): + """Test log configuration extraction""" + parser = create_parser('Test Description') + + # Test default config + args = parser.parse_args([]) + config = get_log_config(args) + self.assertEqual(config['level'], LogLevel.INFO) + self.assertEqual(config['format_type'], LogFormat.TEXT) + + # Test debug flag + args = parser.parse_args(['--debug']) + config = get_log_config(args) + self.assertEqual(config['level'], LogLevel.DEBUG) + + # Test quiet flag + args = parser.parse_args(['--quiet']) + config = get_log_config(args) + self.assertEqual(config['level'], LogLevel.QUIET) + + # Test explicit log level + args = parser.parse_args(['--log-level', 'warning']) + config = get_log_config(args) + self.assertEqual(config['level'], LogLevel.WARNING) + + # Test log format + args = parser.parse_args(['--log-format', 'json']) + config = get_log_config(args) + self.assertEqual(config['format_type'], LogFormat.JSON) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_formatting.py b/tests/test_formatting.py new file mode 100644 index 0000000..7d45f2e --- /dev/null +++ b/tests/test_formatting.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 + +import unittest +import json +from datetime import datetime +from tools.common.formatting import ( + format_cost, + format_duration, + format_file_size, + format_timestamp, + format_output +) + +class TestFormatting(unittest.TestCase): + def test_format_cost(self): + """Test cost formatting""" + # Test zero cost + self.assertEqual(format_cost(0), "$0.000000") + + # Test small cost + self.assertEqual(format_cost(0.0001), "$0.000100") + + # Test larger cost + self.assertEqual(format_cost(123.456789), "$123.456789") + + # Test negative cost + with self.assertRaises(ValueError): + format_cost(-1.0) + + def test_format_duration(self): + """Test duration formatting""" + # Test seconds + self.assertEqual(format_duration(0), "0.00s") + self.assertEqual(format_duration(45.678), "45.68s") + + # Test minutes + self.assertEqual(format_duration(65), "1.08m") + self.assertEqual(format_duration(1800), "30.00m") # 30 minutes + + # Test hours + self.assertEqual(format_duration(3600), "1.00h") # 1 hour + self.assertEqual(format_duration(7200), "2.00h") # 2 hours + + # Test negative duration + with self.assertRaises(ValueError): + format_duration(-1) + + def test_format_file_size(self): + """Test file size formatting""" + # Test bytes + self.assertEqual(format_file_size(0), "0.0B") + self.assertEqual(format_file_size(100), "100.0B") + + # Test kilobytes + self.assertEqual(format_file_size(1024), "1.0KB") + self.assertEqual(format_file_size(2048), "2.0KB") + + # Test megabytes + self.assertEqual(format_file_size(1024 * 1024), "1.0MB") + self.assertEqual(format_file_size(2 * 1024 * 1024), "2.0MB") + + # Test gigabytes + self.assertEqual(format_file_size(1024 * 1024 * 1024), "1.0GB") + + # Test terabytes + self.assertEqual(format_file_size(1024 * 1024 * 1024 * 1024), "1.0TB") + + def test_format_timestamp(self): + """Test timestamp formatting""" + # Test specific timestamp + timestamp = 1609459200 # 2021-01-01 00:00:00 + expected = "2021-01-01 00:00:00" + self.assertEqual(format_timestamp(timestamp), expected) + + # Test current timestamp + current = datetime.now().timestamp() + formatted = format_timestamp(current) + self.assertIsInstance(formatted, str) + self.assertEqual(len(formatted), 19) # YYYY-MM-DD HH:MM:SS + + def 
test_format_output_text(self): + """Test text output formatting""" + # Test string data + result = format_output("Test string", format_type='text') + self.assertEqual(result, "Test string") + + # Test dict data + data = {"key1": "value1", "key2": "value2"} + result = format_output(data, format_type='text', title="Test Title") + self.assertIn("Test Title", result) + self.assertIn("key1: value1", result) + self.assertIn("key2: value2", result) + + # Test nested dict + data = {"section": {"key1": "value1", "key2": "value2"}} + result = format_output(data, format_type='text') + self.assertIn("section:", result) + self.assertIn("key1: value1", result) + + # Test list of dicts + data = [ + {"key1": "value1"}, + {"key2": "value2"} + ] + result = format_output(data, format_type='text') + self.assertIn("Result 1:", result) + self.assertIn("Result 2:", result) + self.assertIn("key1: value1", result) + self.assertIn("key2: value2", result) + + # Test with metadata + metadata = {"meta1": "value1"} + result = format_output("Test", format_type='text', metadata=metadata) + self.assertIn("meta1: value1", result) + + def test_format_output_json(self): + """Test JSON output formatting""" + # Test string data + result = format_output("Test string", format_type='json') + data = json.loads(result) + self.assertEqual(data["data"], "Test string") + + # Test with title and metadata + result = format_output( + "Test string", + format_type='json', + title="Test Title", + metadata={"meta1": "value1"} + ) + data = json.loads(result) + self.assertEqual(data["title"], "Test Title") + self.assertEqual(data["metadata"]["meta1"], "value1") + + # Test complex data + test_data = { + "key1": "value1", + "nested": { + "key2": "value2" + } + } + result = format_output(test_data, format_type='json') + data = json.loads(result) + self.assertEqual(data["data"]["key1"], "value1") + self.assertEqual(data["data"]["nested"]["key2"], "value2") + + def test_format_output_markdown(self): + """Test markdown output formatting""" + # Test string data + result = format_output("Test string", format_type='markdown') + self.assertEqual(result, "Test string") + + # Test with title + result = format_output("Test string", format_type='markdown', title="Test Title") + self.assertIn("# Test Title", result) + + # Test dict data + data = {"key1": "value1", "key2": "value2"} + result = format_output(data, format_type='markdown') + self.assertIn("**key1**: value1", result) + self.assertIn("**key2**: value2", result) + + # Test nested dict + data = {"section": {"key1": "value1"}} + result = format_output(data, format_type='markdown') + self.assertIn("## section", result) + self.assertIn("**key1**: value1", result) + + # Test list of dicts + data = [ + {"key1": "value1"}, + {"key2": "value2"} + ] + result = format_output(data, format_type='markdown') + self.assertIn("## Result 1", result) + self.assertIn("## Result 2", result) + + # Test with metadata + metadata = {"meta1": "value1"} + result = format_output("Test", format_type='markdown', metadata=metadata) + self.assertIn("*meta1: value1*", result) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_llm_api.py b/tests/test_llm_api.py index 12e387f..bb11ee6 100644 --- a/tests/test_llm_api.py +++ b/tests/test_llm_api.py @@ -1,11 +1,34 @@ import unittest +import pytest from unittest.mock import patch, MagicMock, mock_open -from tools.llm_api import create_llm_client, query_llm, load_environment +from tools.llm_api import ( + create_llm_client, + query_llm, + 
load_environment, + LLMApiError, + encode_image_file, + get_default_model, + read_content_or_file +) from tools.token_tracker import TokenUsage, APIResponse import os import google.generativeai as genai import io import sys +import base64 +import mimetypes +import tempfile +from tools.common.errors import FileError, APIError + +class AsyncContextManagerMock: + def __init__(self, response): + self.response = response + + async def __aenter__(self): + return self.response + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass class TestEnvironmentLoading(unittest.TestCase): def setUp(self): @@ -176,8 +199,9 @@ def test_create_gemini_client(self, mock_genai): self.assertEqual(client, mock_genai) def test_create_invalid_provider(self): - with self.assertRaises(ValueError): + with self.assertRaises(LLMApiError) as cm: create_llm_client("invalid_provider") + self.assertIn("Unsupported provider: invalid_provider", str(cm.exception)) @patch('tools.llm_api.OpenAI') def test_query_openai(self, mock_create_client): @@ -268,8 +292,159 @@ def test_query_with_existing_client(self, mock_create_client): def test_query_error(self, mock_create_client): self.mock_openai_client.chat.completions.create.side_effect = Exception("Test error") mock_create_client.return_value = self.mock_openai_client - response = query_llm("Test prompt") - self.assertIsNone(response) + with self.assertRaises(LLMApiError) as cm: + query_llm("Test prompt") + self.assertIn("Failed to query openai LLM: Test error", str(cm.exception)) + + def test_encode_image_file(self): + """Test image file encoding""" + # Create a test image file + test_image = b'fake image data' + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: + f.write(test_image) + test_file = f.name + + try: + # Test PNG image + with patch('mimetypes.guess_type', return_value=('image/png', None)): + base64_data, mime_type = encode_image_file(test_file) + self.assertEqual(mime_type, 'image/png') + self.assertEqual(base64.b64decode(base64_data), test_image) + + # Test JPEG image + with patch('mimetypes.guess_type', return_value=('image/jpeg', None)): + base64_data, mime_type = encode_image_file(test_file) + self.assertEqual(mime_type, 'image/jpeg') + self.assertEqual(base64.b64decode(base64_data), test_image) + + # Test unknown type defaulting to PNG + with patch('mimetypes.guess_type', return_value=(None, None)): + base64_data, mime_type = encode_image_file(test_file) + self.assertEqual(mime_type, 'image/png') + self.assertEqual(base64.b64decode(base64_data), test_image) + finally: + os.unlink(test_file) + + def test_get_default_model(self): + """Test default model selection""" + # Test OpenAI default + self.assertEqual(get_default_model("openai"), "gpt-4o") + + # Test Azure default + self.assertEqual(get_default_model("azure"), "test-model-deployment") + + # Test Anthropic default + self.assertEqual(get_default_model("anthropic"), "claude-3-sonnet-20240229") + + # Test Gemini default + self.assertEqual(get_default_model("gemini"), "gemini-pro") + + # Test Deepseek default + self.assertEqual(get_default_model("deepseek"), "deepseek-chat") + + # Test invalid provider + with self.assertRaises(LLMApiError) as cm: + get_default_model("invalid") + self.assertIn("Invalid provider", str(cm.exception)) + + def test_query_llm_with_image(self): + """Test LLM querying with image input""" + # Create a test image file + test_image = b'fake image data' + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: + f.write(test_image) + test_file = f.name 
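+        # The temp file above only needs to exist on disk: encode_image_file is
+        # patched in the try block below, so its fake bytes are never decoded as an image.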
+ + try: + # Mock image encoding + with patch('tools.llm_api.encode_image_file') as mock_encode: + mock_encode.return_value = ('base64data', 'image/png') + + # Test OpenAI vision model + response = query_llm( + "Describe this image", + client=self.mock_openai_client, + provider="openai", + model="gpt-4-vision-preview", + image_path=test_file + ) + self.assertEqual(response, "Test OpenAI response") + + # Verify image was included in messages + calls = self.mock_openai_client.chat.completions.create.call_args_list + messages = calls[0][1]['messages'] + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]['role'], 'user') + self.assertEqual(len(messages[0]['content']), 2) + self.assertEqual(messages[0]['content'][0]['type'], 'text') + self.assertEqual(messages[0]['content'][1]['type'], 'image_url') + + # Test Gemini vision model + response = query_llm( + "Describe this image", + client=self.mock_gemini_client, + provider="gemini", + model="gemini-pro-vision", + image_path=test_file + ) + self.assertEqual(response, "Test Gemini response") + finally: + os.unlink(test_file) + + def test_provider_specific_errors(self): + """Test provider-specific error handling""" + # Test OpenAI rate limit error + self.mock_openai_client.chat.completions.create.side_effect = Exception("Rate limit exceeded") + with self.assertRaises(LLMApiError) as cm: + query_llm("Test prompt", client=self.mock_openai_client, provider="openai") + self.assertIn("Rate limit exceeded", str(cm.exception)) + + # Test Anthropic timeout error + self.mock_anthropic_client.messages.create.side_effect = Exception("Request timed out") + with self.assertRaises(LLMApiError) as cm: + query_llm("Test prompt", client=self.mock_anthropic_client, provider="anthropic") + self.assertIn("Request timed out", str(cm.exception)) + + # Test Gemini API error + self.mock_gemini_model.generate_content.side_effect = Exception("API error") + with self.assertRaises(LLMApiError) as cm: + query_llm("Test prompt", client=self.mock_gemini_client, provider="gemini") + self.assertIn("API error", str(cm.exception)) + + def test_read_content_or_file(self): + """Test content/file reading""" + # Test direct content + content = read_content_or_file("Test content") + self.assertEqual(content, "Test content") + + # Create a test file + with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: + f.write('File content') + test_file = f.name + + try: + # Test file content + content = read_content_or_file(f"@{test_file}") + self.assertEqual(content, "File content") + + # Test empty input + content = read_content_or_file("") + self.assertEqual(content, "") + finally: + os.unlink(test_file) + +def test_create_invalid_provider(): + with pytest.raises(LLMApiError) as exc_info: + create_llm_client("invalid_provider") + assert "Unsupported provider: invalid_provider" in str(exc_info.value) + +def test_create_llm_client_missing_api_key(): + # Clear the environment variable + if 'OPENAI_API_KEY' in os.environ: + del os.environ['OPENAI_API_KEY'] + with pytest.raises(LLMApiError) as exc_info: + create_llm_client("openai") + assert "OPENAI_API_KEY not found in environment variables" in str(exc_info.value) if __name__ == '__main__': unittest.main() diff --git a/tests/test_logging_config.py b/tests/test_logging_config.py new file mode 100644 index 0000000..38693f3 --- /dev/null +++ b/tests/test_logging_config.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 + +import unittest +from unittest.mock import patch, MagicMock +import logging +import json +import os +import sys 
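+# os and threading are used below to assert the PID/TID fields the formatters emit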
+import threading +from tools.common.logging_config import ( + LogFormat, + LogLevel, + StructuredFormatter, + JSONFormatter, + setup_logging +) + +class TestLoggingConfig(unittest.TestCase): + def setUp(self): + """Set up test fixtures""" + # Store original logging configuration + self.original_loggers = dict(logging.root.manager.loggerDict) + self.original_handlers = list(logging.root.handlers) + + def tearDown(self): + """Clean up test fixtures""" + # Restore original logging configuration + logging.root.handlers = self.original_handlers + for logger_name in list(logging.root.manager.loggerDict.keys()): + if logger_name not in self.original_loggers: + del logging.root.manager.loggerDict[logger_name] + + def test_log_level_conversion(self): + """Test log level string to enum conversion""" + # Test valid levels + self.assertEqual(LogLevel.from_string("debug"), LogLevel.DEBUG) + self.assertEqual(LogLevel.from_string("info"), LogLevel.INFO) + self.assertEqual(LogLevel.from_string("warning"), LogLevel.WARNING) + self.assertEqual(LogLevel.from_string("error"), LogLevel.ERROR) + self.assertEqual(LogLevel.from_string("quiet"), LogLevel.QUIET) + + # Test case insensitivity + self.assertEqual(LogLevel.from_string("DEBUG"), LogLevel.DEBUG) + self.assertEqual(LogLevel.from_string("Info"), LogLevel.INFO) + + # Test invalid level + with self.assertRaises(ValueError) as cm: + LogLevel.from_string("invalid") + self.assertIn("Invalid log level", str(cm.exception)) + self.assertIn("valid levels are", str(cm.exception).lower()) + + def test_log_level_to_logging_level(self): + """Test conversion to standard logging levels""" + self.assertEqual(LogLevel.DEBUG.to_logging_level(), logging.DEBUG) + self.assertEqual(LogLevel.INFO.to_logging_level(), logging.INFO) + self.assertEqual(LogLevel.WARNING.to_logging_level(), logging.WARNING) + self.assertEqual(LogLevel.ERROR.to_logging_level(), logging.ERROR) + self.assertTrue(LogLevel.QUIET.to_logging_level() > logging.ERROR) + + def test_structured_formatter(self): + """Test structured text formatter""" + formatter = StructuredFormatter() + + # Create a test record + record = logging.LogRecord( + name="test_logger", + level=logging.INFO, + pathname="test.py", + lineno=123, + msg="Test message", + args=(), + exc_info=None + ) + record.correlation_id = "test-id" + record.context = {"key": "value"} + + # Format record + output = formatter.format(record) + + # Verify required components + self.assertIn("[INFO]", output) + self.assertIn("[test_logger]", output) + self.assertIn("[test-id]", output) + self.assertIn(f"[PID:{os.getpid()}]", output) + self.assertIn(f"[TID:{threading.get_ident()}]", output) + self.assertIn("[test.py:123]", output) + self.assertIn("key=value", output) + self.assertIn("Test message", output) + + # Test exception formatting + try: + raise ValueError("Test error") + except ValueError: + record.exc_info = sys.exc_info() + output = formatter.format(record) + self.assertIn("ValueError: Test error", output) + + def test_json_formatter(self): + """Test JSON formatter""" + formatter = JSONFormatter() + + # Create a test record + record = logging.LogRecord( + name="test_logger", + level=logging.INFO, + pathname="test.py", + lineno=123, + msg="Test message", + args=(), + exc_info=None + ) + record.correlation_id = "test-id" + record.context = {"key": "value"} + + # Format record + output = formatter.format(record) + data = json.loads(output) + + # Verify required fields + self.assertEqual(data["level"], "INFO") + self.assertEqual(data["logger"], 
"test_logger") + self.assertEqual(data["correlation_id"], "test-id") + self.assertEqual(data["process_id"], os.getpid()) + self.assertEqual(data["thread_id"], threading.get_ident()) + self.assertEqual(data["file"], "test.py") + self.assertEqual(data["line"], 123) + self.assertEqual(data["message"], "Test message") + self.assertEqual(data["context"]["key"], "value") + + # Test exception formatting + try: + raise ValueError("Test error") + except ValueError: + record.exc_info = sys.exc_info() + output = formatter.format(record) + data = json.loads(output) + self.assertEqual(data["exception"]["type"], "ValueError") + self.assertEqual(data["exception"]["message"], "Test error") + self.assertIn("Traceback", data["exception"]["traceback"]) + + def test_setup_logging(self): + """Test logging setup""" + # Test with string level + logger = setup_logging("test_logger", level="debug") + self.assertEqual(logger.level, logging.DEBUG) + self.assertEqual(len(logger.handlers), 1) + self.assertIsInstance(logger.handlers[0], logging.StreamHandler) + + # Test with enum level + logger = setup_logging("test_logger", level=LogLevel.INFO) + self.assertEqual(logger.level, logging.INFO) + + # Test different formats + logger = setup_logging("test_logger", format_type=LogFormat.JSON) + self.assertIsInstance(logger.handlers[0].formatter, JSONFormatter) + + logger = setup_logging("test_logger", format_type=LogFormat.STRUCTURED) + self.assertIsInstance(logger.handlers[0].formatter, StructuredFormatter) + + logger = setup_logging("test_logger", format_type=LogFormat.TEXT) + self.assertIsInstance(logger.handlers[0].formatter, logging.Formatter) + + # Test handler replacement + logger = setup_logging("test_logger") + original_handler_count = len(logger.handlers) + logger = setup_logging("test_logger") # Setup again + self.assertEqual(len(logger.handlers), original_handler_count) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_plan_exec_llm.py b/tests/test_plan_exec_llm.py index 07f2a8c..871a0c5 100644 --- a/tests/test_plan_exec_llm.py +++ b/tests/test_plan_exec_llm.py @@ -2,15 +2,42 @@ import unittest import os -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, AsyncMock from pathlib import Path import sys +import tempfile +import json +import pytest +import aiohttp +from tools.plan_exec_llm import ( + load_environment, + read_plan_status, + read_file_content, + query_llm, + read_content_or_file, + validate_file_path, + validate_plan, + validate_execution_result, + execute_plan, + format_output +) +from tools.common.errors import ValidationError, FileError, APIError +from argparse import Namespace # Add the parent directory to the Python path so we can import the module sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from tools.plan_exec_llm import load_environment, read_plan_status, read_file_content, create_llm_client, query_llm from tools.plan_exec_llm import TokenUsage +class AsyncContextManagerMock: + def __init__(self, response): + self.response = response + + async def __aenter__(self): + return self.response + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + class TestPlanExecLLM(unittest.TestCase): def setUp(self): """Set up test fixtures""" @@ -19,18 +46,19 @@ def setUp(self): # Set test environment variables os.environ['OPENAI_API_KEY'] = 'test_key' - self.test_env_content = """ -OPENAI_API_KEY=test_key -""" - self.test_plan_content = """ -# Multi-Agent Scratchpad -Test content 
-""" - # Create temporary test files - with open('.env.test', 'w') as f: - f.write(self.test_env_content) - with open('.cursorrules.test', 'w') as f: - f.write(self.test_plan_content) + # Create a temporary test environment file + self.env_file = '.env.test' + with open(self.env_file, 'w') as f: + f.write('OPENAI_API_KEY=test_key\n') + + # Create a temporary status file + self.status_file = '.cursorrules' + with open(self.status_file, 'w') as f: + f.write('Some content\n# Multi-Agent Scratchpad\nTest content') + + # Patch the logger + self.logger_patcher = patch('tools.plan_exec_llm.logger') + self.mock_logger = self.logger_patcher.start() def tearDown(self): """Clean up test fixtures""" @@ -38,43 +66,94 @@ def tearDown(self): os.environ.clear() os.environ.update(self.original_env) - # Remove temporary test files - for file in ['.env.test', '.cursorrules.test']: - if os.path.exists(file): - os.remove(file) + # Clean up test files + for file in [self.env_file, self.status_file, 'test_file.txt', 'test_dir']: + try: + if os.path.isdir(file): + os.rmdir(file) + elif os.path.exists(file): + os.remove(file) + except OSError as e: + print(f"Warning: Failed to clean up {file}: {e}") + + # Ensure the status file is removed + try: + if os.path.exists(self.status_file): + os.remove(self.status_file) + except OSError as e: + print(f"Warning: Failed to remove status file: {e}") + + self.logger_patcher.stop() - @patch('tools.plan_exec_llm.load_dotenv') - def test_load_environment(self, mock_load_dotenv): + def test_load_environment(self): """Test environment loading""" - load_environment() - mock_load_dotenv.assert_called() + # Test with existing file + env_loaded = load_environment() + self.assertTrue(env_loaded) + self.assertEqual(os.getenv('OPENAI_API_KEY'), 'test_key') + + def test_validate_file_path(self): + """Test file path validation""" + # Test with existing file + path = validate_file_path(self.status_file) + self.assertTrue(path.exists()) + + # Test with non-existent file + with self.assertRaises(FileError) as cm: + validate_file_path('nonexistent_file.txt') + self.assertIn("File not found", str(cm.exception)) + self.assertIn("nonexistent_file.txt", str(cm.exception)) + + # Test with directory + os.makedirs('test_dir', exist_ok=True) + try: + with self.assertRaises(FileError) as cm: + validate_file_path('test_dir') + self.assertIn("Not a file", str(cm.exception)) + self.assertIn("test_dir", str(cm.exception)) + finally: + if os.path.exists('test_dir'): + os.rmdir('test_dir') def test_read_plan_status(self): """Test reading plan status""" - with patch('tools.plan_exec_llm.STATUS_FILE', '.cursorrules.test'): - content = read_plan_status() - self.assertIn('# Multi-Agent Scratchpad', content) - self.assertIn('Test content', content) + # Test with existing file + content = read_plan_status() + self.assertIn('Test content', content) + + # Test with missing section + with open(self.status_file, 'w') as f: + f.write("No scratchpad section") + with self.assertRaises(ValidationError) as cm: + read_plan_status() + self.assertIn("section not found", str(cm.exception)) + + # Test with empty section + with open(self.status_file, 'w') as f: + f.write("# Multi-Agent Scratchpad\n ") + with self.assertRaises(ValidationError) as cm: + read_plan_status() + self.assertIn("section is empty", str(cm.exception)) + + # Test with missing file + if os.path.exists(self.status_file): + os.remove(self.status_file) + with self.assertRaises(FileError) as cm: + read_plan_status() + self.assertIn("File not found", 
str(cm.exception)) + self.assertIn(".cursorrules", str(cm.exception)) def test_read_file_content(self): """Test reading file content""" # Test with existing file - content = read_file_content('.env.test') + content = read_file_content(self.env_file) self.assertIn('OPENAI_API_KEY=test_key', content) - + # Test with non-existent file - content = read_file_content('nonexistent_file.txt') - self.assertIsNone(content) - - @patch('tools.plan_exec_llm.OpenAI') - def test_create_llm_client(self, mock_openai): - """Test LLM client creation""" - mock_client = MagicMock() - mock_openai.return_value = mock_client - - client = create_llm_client() - self.assertEqual(client, mock_client) - mock_openai.assert_called_once_with(api_key='test_key') + with self.assertRaises(FileError) as cm: + read_file_content('nonexistent_file.txt') + self.assertIn("File not found", str(cm.exception)) + self.assertIn("nonexistent_file.txt", str(cm.exception)) @patch('tools.plan_exec_llm.create_llm_client') def test_query_llm(self, mock_create_client): @@ -90,31 +169,338 @@ def test_query_llm(self, mock_create_client): mock_response.usage.total_tokens = 15 mock_response.usage.completion_tokens_details = MagicMock() mock_response.usage.completion_tokens_details.reasoning_tokens = None - + mock_client = MagicMock() mock_client.chat.completions.create.return_value = mock_response mock_create_client.return_value = mock_client - - # Test with various combinations of parameters - response = query_llm("Test plan", "Test prompt", "Test file content") + + # Test with empty plan content + with self.assertRaises(ValidationError) as cm: + query_llm("") + self.assertIn("Plan content cannot be empty", str(cm.exception)) + + # Test with valid plan content + response = query_llm("Test plan") self.assertEqual(response, "Test response") + + # Test error handling + mock_client.chat.completions.create.side_effect = Exception("Test error") + with self.assertRaises(APIError) as cm: + query_llm("Test plan") + self.assertIn("Failed to query LLM", str(cm.exception)) - response = query_llm("Test plan", "Test prompt") - self.assertEqual(response, "Test response") + def test_read_content_or_file(self): + """Test reading content with @ prefix""" + # Test direct string content + content = read_content_or_file("Test content") + self.assertEqual(content, "Test content") + + # Test empty content + content = read_content_or_file("") + self.assertEqual(content, "") + + # Test file content + with open('test_file.txt', 'w') as f: + f.write("File content") + try: + content = read_content_or_file("@test_file.txt") + self.assertEqual(content, "File content") + finally: + if os.path.exists('test_file.txt'): + os.remove('test_file.txt') + + # Test missing file + with self.assertRaises(FileError) as cm: + read_content_or_file("@nonexistent.txt") + self.assertIn("Failed to read file", str(cm.exception)) + self.assertIn("nonexistent.txt", str(cm.exception)) - response = query_llm("Test plan") - self.assertEqual(response, "Test response") + def test_format_output(self): + """Test output formatting""" + # Test text format + result = format_output("Test output", "text") + self.assertIn("Following is the instruction", result) + self.assertIn("Test output", result) + + # Test JSON format + result = format_output("Test output", "json") + data = json.loads(result) + self.assertEqual(data["response"], "Test output") + self.assertEqual(data["model"], "o1") + self.assertEqual(data["provider"], "openai") - # Verify the OpenAI client was called with correct parameters - 
mock_client.chat.completions.create.assert_called_with( - model="o1", - messages=[ - {"role": "system", "content": ""}, - {"role": "user", "content": unittest.mock.ANY} - ], - response_format={"type": "text"}, - reasoning_effort="low" + # Test markdown format + result = format_output("Test output", "markdown") + self.assertIn("# Plan Execution Response", result) + self.assertIn("Test output", result) + self.assertIn("*Model: o1 (OpenAI)*", result) + + # Test invalid format + with self.assertRaises(ValidationError) as cm: + format_output("Test output", "invalid") + self.assertIn("Invalid format type", str(cm.exception)) + + @pytest.mark.asyncio + async def test_execute_plan(self): + """Test plan execution""" + # Test web search step + plan = { + "goal": "Test goal", + "steps": [ + { + "id": "step1", + "name": "Search Step", + "type": "web_search", + "params": { + "query": "test query", + "max_results": 2 + } + } + ] + } + + # Mock search function + with patch('tools.plan_exec_llm.search') as mock_search: + mock_search.return_value = [ + {"url": "http://example.com", "title": "Test", "snippet": "Test snippet"} + ] + + result = await execute_plan(plan) + self.assertEqual(result["steps"][0]["status"], "success") + self.assertIn("http://example.com", result["steps"][0]["output"]) + + # Test web scrape step + plan = { + "goal": "Test goal", + "steps": [ + { + "id": "step1", + "name": "Scrape Step", + "type": "web_scrape", + "params": { + "urls": ["http://example.com"] + } + } + ] + } + + # Mock scraping function + with patch('tools.plan_exec_llm.process_urls') as mock_scrape: + mock_scrape.return_value = { + "results": [{"url": "http://example.com", "content": "Test content"}], + "errors": {} + } + + result = await execute_plan(plan) + self.assertEqual(result["steps"][0]["status"], "success") + self.assertIn("Test content", result["steps"][0]["output"]) + + # Test screenshot step + plan = { + "goal": "Test goal", + "steps": [ + { + "id": "step1", + "name": "Screenshot Step", + "type": "screenshot", + "params": { + "url": "http://example.com" + } + } + ] + } + + # Mock screenshot function + with patch('tools.plan_exec_llm.take_screenshot') as mock_screenshot: + mock_screenshot.return_value = "/tmp/test.png" + + result = await execute_plan(plan) + self.assertEqual(result["steps"][0]["status"], "success") + self.assertIn("/tmp/test.png", result["steps"][0]["output"]) + + # Test step failure + plan = { + "goal": "Test goal", + "steps": [ + { + "id": "step1", + "name": "Failed Step", + "type": "web_search", + "params": { + "query": "test query" + } + } + ] + } + + with patch('tools.plan_exec_llm.search') as mock_search: + mock_search.side_effect = APIError("Search failed", "search") + + result = await execute_plan(plan) + self.assertEqual(result["steps"][0]["status"], "error") + self.assertIn("Search failed", result["steps"][0]["error"]) + + @patch('tools.plan_exec_llm.create_llm_client') + def test_query_llm_with_system_prompt(self, mock_create_client): + """Test LLM querying with system prompt""" + # Mock OpenAI response + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message = MagicMock() + mock_response.choices[0].message.content = "Test response" + mock_response.usage = MagicMock() + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 5 + mock_response.usage.total_tokens = 15 + mock_response.usage.completion_tokens_details = MagicMock() + mock_response.usage.completion_tokens_details.reasoning_tokens = 3 + + 
mock_client = MagicMock() + mock_client.chat.completions.create.return_value = mock_response + mock_create_client.return_value = mock_client + + # Test with system prompt + response = query_llm( + "Test plan", + system_prompt="You are a helpful assistant" ) + self.assertEqual(response, "Test response") + + # Verify system message was included + calls = mock_client.chat.completions.create.call_args_list + messages = calls[0][1]['messages'] + self.assertEqual(messages[0]['role'], 'system') + self.assertEqual(messages[0]['content'], 'You are a helpful assistant') + + @patch('argparse.ArgumentParser.parse_args') + @patch('tools.plan_exec_llm.execute_plan') + @patch('tools.plan_exec_llm.asyncio') + @patch('tools.plan_exec_llm.validate_plan') + @patch('tools.plan_exec_llm.read_content_or_file') + async def test_main_function(self, mock_read, mock_validate, mock_execute, mock_parse_args): + """Test main function execution""" + # Mock command line arguments + mock_args = MagicMock() + mock_args.plan = '@test_plan.json' + mock_args.format = 'text' + mock_args.log_level = 'info' + mock_args.log_format = 'text' + mock_args.timeout = 300 + mock_parse_args.return_value = mock_args + + # Create test plan file + plan = { + "goal": "Test goal", + "steps": [ + { + "id": "step1", + "name": "Test Step", + "type": "web_search", + "params": { + "query": "test query" + } + } + ] + } + with open('test_plan.json', 'w') as f: + json.dump(plan, f) + + # Mock plan execution + mock_execute.return_value = { + "total_steps": 1, + "successful_steps": 1, + "failed_steps": 0, + "steps": [ + { + "id": "step1", + "status": "success", + "output": "Test output" + } + ] + } + mock_asyncio.run.return_value = mock_execute.return_value + + try: + # Run main function + with patch('sys.argv', ['plan_exec_llm.py', '@test_plan.json']): + from tools.plan_exec_llm import main + await main() + + # Verify plan was executed + mock_asyncio.run.assert_called_once() + + # Verify calls + mock_read.assert_called_once_with('test_plan.json') + mock_validate.assert_called_once() + mock_execute.assert_called_once() + + finally: + # Clean up + if os.path.exists('test_plan.json'): + os.unlink('test_plan.json') + + def test_validate_execution_result(self): + """Test execution result validation""" + # Test valid result + valid_result = { + "step_id": "step1", + "success": True, + "output": "Test output" + } + validate_execution_result(valid_result) # Should not raise + + # Test missing fields + with self.assertRaises(ValidationError) as cm: + validate_execution_result({}) + self.assertIn("Missing required result keys", str(cm.exception)) + + # Test invalid status + invalid_status = { + "step_id": "step1", + "success": "invalid", + "output": "Test output" + } + with self.assertRaises(ValidationError) as cm: + validate_execution_result(invalid_status) + self.assertIn("Success must be a boolean", str(cm.exception)) + + def test_validate_plan(self): + """Test plan validation""" + # Test valid plan + valid_plan = { + "goal": "Test goal", + "steps": [ + { + "id": "step1", + "description": "Test step", + "action": "web_search", + "expected_result": "Search results", + "params": { + "query": "test query" + } + } + ] + } + validate_plan(valid_plan) # Should not raise + + # Test missing plan keys + with self.assertRaises(ValidationError) as cm: + validate_plan({}) + self.assertIn("Missing required plan keys", str(cm.exception)) + + # Test invalid steps type + with self.assertRaises(ValidationError) as cm: + validate_plan({"goal": "test", "steps": "invalid"}) + 
self.assertIn("Steps must be a list", str(cm.exception)) + + # Test missing step keys + invalid_step = { + "goal": "test", + "steps": [{"id": "step1"}] + } + with self.assertRaises(ValidationError) as cm: + validate_plan(invalid_step) + self.assertIn("Step 0 missing required keys", str(cm.exception)) if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tests/test_screenshot_utils.py b/tests/test_screenshot_utils.py new file mode 100644 index 0000000..bd2fff5 --- /dev/null +++ b/tests/test_screenshot_utils.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python3 + +import unittest +from unittest.mock import patch, MagicMock, AsyncMock +import os +import tempfile +from pathlib import Path +import pytest +import asyncio +from playwright.async_api import Error as PlaywrightError +from tools.screenshot_utils import ( + validate_url, + validate_dimensions, + take_screenshot, + take_screenshot_sync, + ScreenshotError +) +from tools.common.errors import ValidationError, FileError + +class AsyncContextManagerMock: + def __init__(self, response): + self.response = response + + async def __aenter__(self): + return self.response + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + +class TestScreenshotUtils(unittest.TestCase): + def setUp(self): + """Set up test fixtures""" + self.test_url = "https://example.com" + self.test_output = "test_screenshot.png" + + # Create a temporary directory for test files + self.temp_dir = tempfile.mkdtemp() + self.temp_file = os.path.join(self.temp_dir, "screenshot.png") + + def tearDown(self): + """Clean up test fixtures""" + # Remove temporary files + if os.path.exists(self.temp_file): + os.unlink(self.temp_file) + if os.path.exists(self.temp_dir): + os.rmdir(self.temp_dir) + + def test_validate_url(self): + """Test URL validation""" + # Test valid URLs + self.assertTrue(validate_url("http://example.com")) + self.assertTrue(validate_url("https://example.com")) + self.assertTrue(validate_url("https://sub.example.com/path?query=1")) + + # Test invalid URLs + self.assertFalse(validate_url("")) + self.assertFalse(validate_url("not-a-url")) + self.assertFalse(validate_url("ftp://example.com")) + self.assertFalse(validate_url("http://")) + self.assertFalse(validate_url("https://")) + + def test_validate_dimensions(self): + """Test viewport dimension validation""" + # Test valid dimensions + self.assertEqual(validate_dimensions(800, 600), (800, 600)) + self.assertEqual(validate_dimensions(1, 1), (1, 1)) + self.assertEqual(validate_dimensions(16383, 16383), (16383, 16383)) + + # Test invalid dimensions + with self.assertRaises(ValidationError) as cm: + validate_dimensions(0, 600) + self.assertIn("must be positive", str(cm.exception)) + + with self.assertRaises(ValidationError) as cm: + validate_dimensions(800, 0) + self.assertIn("must be positive", str(cm.exception)) + + with self.assertRaises(ValidationError) as cm: + validate_dimensions(-1, 600) + self.assertIn("must be positive", str(cm.exception)) + + with self.assertRaises(ValidationError) as cm: + validate_dimensions(800, -1) + self.assertIn("must be positive", str(cm.exception)) + + with self.assertRaises(ValidationError) as cm: + validate_dimensions(16385, 600) + self.assertIn("cannot exceed 16384", str(cm.exception)) + + with self.assertRaises(ValidationError) as cm: + validate_dimensions(800, 16385) + self.assertIn("cannot exceed 16384", str(cm.exception)) + +@pytest.mark.asyncio +async def test_take_screenshot_validation(): + """Test screenshot taking input validation""" + # Test 
invalid URL + with pytest.raises(ValidationError): + await take_screenshot("not-a-url") + + # Test invalid dimensions + with pytest.raises(ValidationError): + await take_screenshot("https://example.com", width=0) + with pytest.raises(ValidationError): + await take_screenshot("https://example.com", height=0) + +@pytest.mark.asyncio +@patch('tools.screenshot_utils.async_playwright') +async def test_take_screenshot_success(mock_playwright): + """Test successful screenshot capture""" + # Mock Playwright objects + mock_browser = AsyncMock() + mock_page = AsyncMock() + mock_context = AsyncMock() + + # Set up mock chain + mock_playwright.return_value = AsyncContextManagerMock(mock_context) + mock_context.chromium = AsyncMock() + mock_context.chromium.launch = AsyncMock(return_value=mock_browser) + mock_browser.new_page = AsyncMock(return_value=mock_page) + mock_page.goto = AsyncMock() + + # Mock screenshot to create a file + async def mock_screenshot(path, **kwargs): + with open(path, 'wb') as f: + f.write(b'fake screenshot data') + mock_page.screenshot = AsyncMock(side_effect=mock_screenshot) + + # Test with default output path (temporary file) + result = await take_screenshot("https://example.com") + assert os.path.exists(result) + assert os.path.getsize(result) > 0 + os.unlink(result) # Clean up temp file + + # Test with specified output path + temp_dir = tempfile.mkdtemp() + temp_file = os.path.join(temp_dir, "screenshot.png") + result = await take_screenshot("https://example.com", temp_file) + assert result == temp_file + assert os.path.exists(temp_file) + assert os.path.getsize(temp_file) > 0 + + # Verify calls + mock_context.chromium.launch.assert_called_with(headless=True) + mock_browser.new_page.assert_called_with(viewport={'width': 1280, 'height': 720}) + mock_page.goto.assert_called_with("https://example.com", wait_until='networkidle') + mock_page.screenshot.assert_called_with(path=temp_file, full_page=True) + + # Test with custom dimensions + await take_screenshot("https://example.com", temp_file, width=800, height=600) + mock_browser.new_page.assert_called_with(viewport={'width': 800, 'height': 600}) + + # Clean up + os.unlink(temp_file) + os.rmdir(temp_dir) + +@pytest.mark.asyncio +@patch('tools.screenshot_utils.async_playwright') +async def test_take_screenshot_failure(mock_playwright): + """Test screenshot capture failures""" + # Mock Playwright error + mock_browser = AsyncMock() + mock_context = AsyncMock() + mock_playwright.return_value = AsyncContextManagerMock(mock_context) + mock_context.chromium = AsyncMock() + mock_context.chromium.launch = AsyncMock(return_value=mock_browser) + mock_browser.new_page = AsyncMock(side_effect=PlaywrightError("Test error")) + + # Test screenshot failure + temp_dir = tempfile.mkdtemp() + temp_file = os.path.join(temp_dir, "screenshot.png") + with pytest.raises(ScreenshotError) as cm: + await take_screenshot("https://example.com", temp_file) + assert "Test error" in str(cm.value) + assert cm.value.url == "https://example.com" + assert not os.path.exists(temp_file) # File should not be created on error + + # Test output directory creation failure + bad_path = "/nonexistent/dir/screenshot.png" + with pytest.raises(FileError) as cm: + await take_screenshot("https://example.com", bad_path) + assert "Failed to create output directory" in str(cm.value) + + # Test temporary file creation failure + with patch('tempfile.NamedTemporaryFile', side_effect=OSError("Test error")): + with pytest.raises(FileError) as cm: + await 
take_screenshot("https://example.com") + assert "Failed to create temporary file" in str(cm.value) + + # Clean up + os.rmdir(temp_dir) + +def test_take_screenshot_sync(): + """Test synchronous screenshot capture""" + # Test invalid URL + with pytest.raises(ValidationError): + take_screenshot_sync("not-a-url") + + # Test invalid dimensions + with pytest.raises(ValidationError): + take_screenshot_sync("https://example.com", width=0) + with pytest.raises(ValidationError): + take_screenshot_sync("https://example.com", height=0) + +@patch('tools.screenshot_utils.asyncio.run') +def test_take_screenshot_sync_error_handling(mock_run): + """Test error handling in synchronous screenshot capture""" + # Test ScreenshotError passthrough + mock_run.side_effect = ScreenshotError("Test error", "https://example.com") + with pytest.raises(ScreenshotError) as cm: + take_screenshot_sync("https://example.com") + assert "Test error" in str(cm.value) + + # Test ValidationError passthrough + mock_run.side_effect = ValidationError("Test error") + with pytest.raises(ValidationError): + take_screenshot_sync("https://example.com") + + # Test FileError passthrough + mock_run.side_effect = FileError("Test error", "test.png") + with pytest.raises(FileError): + take_screenshot_sync("https://example.com") + + # Test other exceptions + mock_run.side_effect = Exception("Unexpected error") + with pytest.raises(ScreenshotError) as cm: + take_screenshot_sync("https://example.com") + assert "Unexpected error" in str(cm.value) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_screenshot_verification.py b/tests/test_screenshot_verification.py index 3b4dd6a..28df8d0 100644 --- a/tests/test_screenshot_verification.py +++ b/tests/test_screenshot_verification.py @@ -6,6 +6,7 @@ from tools.screenshot_utils import take_screenshot_sync, take_screenshot from tools.llm_api import query_llm from tools.token_tracker import TokenUsage +from tools.common.errors import FileError class TestScreenshotVerification: @pytest.fixture @@ -112,14 +113,7 @@ def test_llm_verification_openai(self, tmp_path): mock_openai.chat.completions.create.assert_called_once() def test_llm_verification_anthropic(self, tmp_path): - """Test screenshot verification with Anthropic using mocks.""" - screenshot_path = os.path.join(tmp_path, 'test_screenshot.png') - - # Create a dummy screenshot file - os.makedirs(tmp_path, exist_ok=True) - with open(screenshot_path, 'wb') as f: - f.write(b'fake_screenshot_data') - + """Test verification with Anthropic using mocks.""" # Mock the entire Anthropic client chain mock_anthropic = MagicMock() mock_response = MagicMock() @@ -138,10 +132,8 @@ def test_llm_verification_anthropic(self, tmp_path): with patch('tools.llm_api.create_llm_client', return_value=mock_anthropic): response = query_llm( "What is the background color of this webpage? 
What is the title?", - provider="anthropic", - image_path=screenshot_path + provider="anthropic" ) - assert 'blue' in response.lower() assert 'agentic.ai test page' in response.lower() mock_anthropic.messages.create.assert_called_once() diff --git a/tests/test_search_engine.py b/tests/test_search_engine.py index e7c6023..93ff804 100644 --- a/tests/test_search_engine.py +++ b/tests/test_search_engine.py @@ -2,25 +2,89 @@ from unittest.mock import patch, MagicMock import sys from io import StringIO -from tools.search_engine import search +from tools.search_engine import ( + search, + SearchError, + get_search_engine, + DuckDuckGoEngine, + GoogleEngine, + fetch_page_snippet +) +from tools.common.errors import ValidationError +import pytest class TestSearchEngine(unittest.TestCase): def setUp(self): - # Capture stdout and stderr for testing - self.stdout = StringIO() - self.stderr = StringIO() - self.old_stdout = sys.stdout - self.old_stderr = sys.stderr - sys.stdout = self.stdout - sys.stderr = self.stderr + self.mock_logger = MagicMock() + patch('tools.search_engine.logger', self.mock_logger).start() def tearDown(self): - # Restore stdout and stderr - sys.stdout = self.old_stdout - sys.stderr = self.old_stderr + patch.stopall() + + @patch('tools.search_engine.requests.get') + def test_fetch_page_snippet(self, mock_get): + # Mock successful response + mock_response = MagicMock() + mock_response.text = ''' + + + Test Title + + + +

+            <p>First paragraph</p>
+            <p>Second paragraph with enough text to be considered a valid snippet candidate that should be selected for the result</p>
+ + + ''' + mock_get.return_value = mock_response + + # Test with meta description + title, snippet = fetch_page_snippet("http://example.com", "test query") + self.assertEqual(title, "Test Title") + self.assertEqual(snippet, "Meta description snippet") + + # Test without meta description + mock_response.text = ''' + + + Test Title + + +

+            <p>Short text</p>
+            <p>Second paragraph with enough text to be considered a valid snippet candidate that should be selected for the result</p>
+ + + ''' + title, snippet = fetch_page_snippet("http://example.com", "test query") + self.assertEqual(title, "Test Title") + self.assertTrue(snippet.startswith("Second paragraph")) + + # Test error handling + mock_get.side_effect = Exception("Connection error") + with self.assertRaises(SearchError) as cm: + fetch_page_snippet("http://example.com", "test query") + self.assertIn("Connection error", str(cm.exception)) + + def test_get_search_engine(self): + # Test valid engines + engine = get_search_engine("duckduckgo") + self.assertIsInstance(engine, DuckDuckGoEngine) + + engine = get_search_engine("google") + self.assertIsInstance(engine, GoogleEngine) + + # Test case insensitivity + engine = get_search_engine("DUCKDUCKGO") + self.assertIsInstance(engine, DuckDuckGoEngine) + + # Test invalid engine + with pytest.raises(ValidationError) as exc_info: + get_search_engine("invalid") + assert "Invalid search engine" in str(exc_info.value) @patch('tools.search_engine.DDGS') - def test_successful_search(self, mock_ddgs): + def test_duckduckgo_search(self, mock_ddgs): # Mock search results mock_results = [ { @@ -39,30 +103,120 @@ def test_successful_search(self, mock_ddgs): mock_ddgs_instance = MagicMock() mock_ddgs_instance.__enter__.return_value.text.return_value = mock_results mock_ddgs.return_value = mock_ddgs_instance + + # Run search + results = search("test query", max_results=2, engine="duckduckgo") + + # Check logging + self.mock_logger.info.assert_any_call( + "Searching for: test query", + extra={ + "context": { + "engine": "duckduckgo", + "max_results": 2, + "max_retries": 3, + "fetch_snippets": True + } + } + ) + + # Check results + self.assertEqual(len(results), 2) + self.assertEqual(results[0]["url"], "http://example.com") + self.assertEqual(results[0]["title"], "Example Title") + self.assertEqual(results[0]["snippet"], "Example Body") + + @patch('tools.search_engine.google_search') + @patch('tools.search_engine.fetch_page_snippet') + def test_google_search_with_snippets(self, mock_fetch_snippet, mock_google): + # Mock search results + mock_google.return_value = [ + 'http://example.com', + 'http://example2.com' + ] + + # Mock snippet fetching + mock_fetch_snippet.side_effect = [ + ("Example Title", "Example Snippet"), + ("Example Title 2", "Example Snippet 2") + ] + + # Run search with snippet fetching + results = search("test query", max_results=2, engine="google", fetch_snippets=True) + + # Check logging + self.mock_logger.info.assert_any_call( + "Searching for: test query", + extra={ + "context": { + "engine": "google", + "max_results": 2, + "max_retries": 3, + "fetch_snippets": True + } + } + ) + + # Check results + self.assertEqual(len(results), 2) + self.assertEqual(results[0]["url"], "http://example.com") + self.assertEqual(results[0]["title"], "Example Title") + self.assertEqual(results[0]["snippet"], "Example Snippet") + + @patch('tools.search_engine.google_search') + @patch('tools.search_engine.fetch_page_snippet') + def test_google_search_without_snippets(self, mock_fetch_snippet, mock_google): + # Mock search results + mock_google.return_value = [ + 'http://example.com', + 'http://example2.com' + ] + + # Run search without snippet fetching + results = search("test query", max_results=2, engine="google", fetch_snippets=False) + + # Check search results + self.assertEqual(len(results), 2) + self.assertEqual(results[0]["url"], "http://example.com") + self.assertEqual(results[0]["title"], "http://example.com") + self.assertEqual(results[0]["snippet"], "") + 
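The hunks above switch `search()` from printing results to returning structured data. A minimal usage sketch, inferred only from these assertions (argument names, result keys, and exception types are taken from the tests rather than re-verified against `tools/search_engine.py`):

```python
# Illustrative only: exercises the search() surface implied by the tests above.
from tools.search_engine import search, SearchError
from tools.common.errors import ValidationError

try:
    results = search("python asyncio tutorial", max_results=5,
                     engine="google", fetch_snippets=True)
    for r in results:
        # Each result is a dict instead of printed text in the new API.
        print(r["url"], r["title"], r["snippet"][:80])
except ValidationError as exc:
    print(f"Bad input: {exc}")      # e.g. empty query, max_results out of range
except SearchError as exc:
    print(f"Search failed: {exc}")  # backend error after retries
```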
self.assertEqual(results[1]["url"], "http://example2.com") + self.assertEqual(results[1]["title"], "http://example2.com") + self.assertEqual(results[1]["snippet"], "") + + # Verify fetch_snippet was not called + mock_fetch_snippet.assert_not_called() + + @patch('tools.search_engine.google_search') + @patch('tools.search_engine.fetch_page_snippet') + def test_google_search_with_failed_snippets(self, mock_fetch_snippet, mock_google): + # Mock search results + mock_google.return_value = [ + 'http://example.com', + 'http://example2.com' + ] + + # Mock snippet fetching with failures + mock_fetch_snippet.side_effect = [ + ("Example Title", "Example Snippet"), + SearchError("Failed to fetch", "google", "test query") + ] # Run search - search("test query", max_results=2) - - # Check debug output - expected_debug = "DEBUG: Searching for query: test query (attempt 1/3)" - self.assertIn(expected_debug, self.stderr.getvalue()) - self.assertIn("DEBUG: Found 2 results", self.stderr.getvalue()) - - # Check search results output - output = self.stdout.getvalue() - self.assertIn("=== Result 1 ===", output) - self.assertIn("URL: http://example.com", output) - self.assertIn("Title: Example Title", output) - self.assertIn("Snippet: Example Body", output) - self.assertIn("=== Result 2 ===", output) - self.assertIn("URL: http://example2.com", output) - self.assertIn("Title: Example Title 2", output) - self.assertIn("Snippet: Example Body 2", output) - - # Verify mock was called correctly - mock_ddgs_instance.__enter__.return_value.text.assert_called_once_with( - "test query", - max_results=2 + results = search("test query", max_results=2, engine="google") + + # Check results - should include both successful and failed snippets + self.assertEqual(len(results), 2) + self.assertEqual(results[0]["url"], "http://example.com") + self.assertEqual(results[0]["title"], "Example Title") + self.assertEqual(results[0]["snippet"], "Example Snippet") + self.assertEqual(results[1]["url"], "http://example2.com") + self.assertEqual(results[1]["title"], "http://example2.com") + self.assertEqual(results[1]["snippet"], "") + + # Verify warning was logged + self.mock_logger.warning.assert_called_with( + "Failed to fetch snippet for http://example2.com: Failed to fetch (query=test query, provider=google)" ) @patch('tools.search_engine.DDGS') @@ -71,15 +225,25 @@ def test_no_results(self, mock_ddgs): mock_ddgs_instance = MagicMock() mock_ddgs_instance.__enter__.return_value.text.return_value = [] mock_ddgs.return_value = mock_ddgs_instance - + # Run search - search("test query") - - # Check debug output - self.assertIn("DEBUG: No results found", self.stderr.getvalue()) - - # Check that no results were printed - self.assertEqual("", self.stdout.getvalue().strip()) + results = search("test query") + + # Check logging + self.mock_logger.info.assert_any_call( + "Searching for: test query", + extra={ + "context": { + "engine": "duckduckgo", + "max_results": 10, + "max_retries": 3, + "fetch_snippets": True + } + } + ) + + # Check results + self.assertEqual(len(results), 0) @patch('tools.search_engine.DDGS') def test_search_error(self, mock_ddgs): @@ -89,30 +253,36 @@ def test_search_error(self, mock_ddgs): mock_ddgs.return_value = mock_ddgs_instance # Run search and check for error - with self.assertRaises(SystemExit) as cm: + with pytest.raises(SearchError) as exc_info: search("test query") - self.assertEqual(cm.exception.code, 1) - self.assertIn("ERROR: Search failed: Test error", self.stderr.getvalue()) - - def 
test_result_field_fallbacks(self): - # Test that the fields work correctly with N/A fallback - result = { - 'href': 'http://example.com', - 'title': 'Example Title', - 'body': 'Example Body' - } - - # Test fields present - self.assertEqual(result.get('href', 'N/A'), 'http://example.com') - self.assertEqual(result.get('title', 'N/A'), 'Example Title') - self.assertEqual(result.get('body', 'N/A'), 'Example Body') - - # Test missing fields - result = {} - self.assertEqual(result.get('href', 'N/A'), 'N/A') - self.assertEqual(result.get('title', 'N/A'), 'N/A') - self.assertEqual(result.get('body', 'N/A'), 'N/A') + assert "DuckDuckGo search failed" in str(exc_info.value) + assert "Test error" in str(exc_info.value) + + def test_invalid_inputs(self): + # Test empty query + with pytest.raises(ValidationError) as exc_info: + search("") + assert "Search query cannot be empty" in str(exc_info.value) + + # Test whitespace query + with pytest.raises(ValidationError) as exc_info: + search(" ") + assert "Search query cannot be empty" in str(exc_info.value) + + # Test invalid max_results + with pytest.raises(ValidationError) as exc_info: + search("test", max_results=0) + assert "max_results must be a positive integer" in str(exc_info.value) + + with pytest.raises(ValidationError) as exc_info: + search("test", max_results=101) + assert "max_results cannot exceed 100" in str(exc_info.value) + + # Test invalid engine + with pytest.raises(ValidationError) as exc_info: + search("test", engine="invalid") + assert "Invalid search engine" in str(exc_info.value) if __name__ == '__main__': unittest.main() diff --git a/tests/test_token_tracker.py b/tests/test_token_tracker.py index 86c86a9..f151a98 100644 --- a/tests/test_token_tracker.py +++ b/tests/test_token_tracker.py @@ -8,10 +8,13 @@ import time from datetime import datetime from tools.token_tracker import TokenTracker, TokenUsage, APIResponse, get_token_tracker, _token_tracker +from tools.common.errors import ValidationError +import shutil class TestTokenTracker(unittest.TestCase): def setUp(self): - # Create a temporary directory for test logs + """Set up test environment""" + self.test_session_id = f"test-{int(time.time())}" self.test_logs_dir = Path("test_token_logs") self.test_logs_dir.mkdir(exist_ok=True) @@ -41,81 +44,123 @@ def setUp(self): ) # Create a TokenTracker instance with a unique test session ID - self.test_session_id = f"test-{int(time.time())}" self.tracker = TokenTracker(self.test_session_id, logs_dir=self.test_logs_dir) self.tracker.session_file = self.test_logs_dir / f"session_{self.test_session_id}.json" def tearDown(self): - # Clean up test logs directory - if self.test_logs_dir.exists(): - for file in self.test_logs_dir.glob("*"): - file.unlink() - self.test_logs_dir.rmdir() + """Clean up test environment""" + shutil.rmtree(self.test_logs_dir) # Reset global token tracker global _token_tracker _token_tracker = None def test_token_usage_creation(self): - """Test TokenUsage dataclass creation""" - token_usage = TokenUsage(100, 50, 150, 20) - self.assertEqual(token_usage.prompt_tokens, 100) - self.assertEqual(token_usage.completion_tokens, 50) - self.assertEqual(token_usage.total_tokens, 150) - self.assertEqual(token_usage.reasoning_tokens, 20) + """Test TokenUsage creation and validation""" + # Test valid token usage + usage = TokenUsage(prompt_tokens=100, completion_tokens=50, total_tokens=150) + self.assertEqual(usage.prompt_tokens, 100) + self.assertEqual(usage.completion_tokens, 50) + self.assertEqual(usage.total_tokens, 150) + 
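The token-tracker tests in this file imply the following construction-and-tracking flow; the sketch below mirrors the constructor arguments, the `track_request()` call, and the summary keys asserted here, and is illustrative rather than the canonical API:

```python
# Illustrative flow based on the TokenTracker tests in this file.
from tools.token_tracker import TokenTracker, TokenUsage, APIResponse

tracker = TokenTracker("demo-session")
usage = TokenUsage(prompt_tokens=100, completion_tokens=50, total_tokens=150)
response = APIResponse(
    content="Hello",
    token_usage=usage,
    # (100 * 3 + 50 * 15) / 1_000_000 at the Claude 3 Sonnet rates asserted below
    cost=0.00105,
    thinking_time=1.2,
    provider="anthropic",
    model="claude-3-sonnet-20240229",
)
tracker.track_request(response)
summary = tracker.get_session_summary()
print(summary["total_tokens"], summary["total_cost"])
```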
self.assertIsNone(usage.reasoning_tokens) + + # Test with reasoning tokens + usage = TokenUsage(prompt_tokens=100, completion_tokens=50, total_tokens=150, reasoning_tokens=20) + self.assertEqual(usage.reasoning_tokens, 20) + + # Test invalid token counts + with self.assertRaises(ValidationError): + TokenUsage(prompt_tokens=-1, completion_tokens=50, total_tokens=150) + with self.assertRaises(ValidationError): + TokenUsage(prompt_tokens=100, completion_tokens=-1, total_tokens=150) + with self.assertRaises(ValidationError): + TokenUsage(prompt_tokens=100, completion_tokens=50, total_tokens=-1) + with self.assertRaises(ValidationError): + TokenUsage(prompt_tokens=100, completion_tokens=50, total_tokens=150, reasoning_tokens=-1) def test_api_response_creation(self): - """Test APIResponse dataclass creation""" + """Test APIResponse creation and validation""" + usage = TokenUsage(prompt_tokens=100, completion_tokens=50, total_tokens=150) + + # Test valid response response = APIResponse( - content="Test", - token_usage=self.test_token_usage, - cost=0.1, - thinking_time=1.0, + content="Test response", + token_usage=usage, + cost=0.123, + thinking_time=1.5, provider="openai", model="o1" ) - self.assertEqual(response.content, "Test") - self.assertEqual(response.token_usage, self.test_token_usage) - self.assertEqual(response.cost, 0.1) - self.assertEqual(response.thinking_time, 1.0) + self.assertEqual(response.content, "Test response") + self.assertEqual(response.token_usage, usage) + self.assertEqual(response.cost, 0.123) + self.assertEqual(response.thinking_time, 1.5) self.assertEqual(response.provider, "openai") self.assertEqual(response.model, "o1") + # Test invalid responses + with self.assertRaises(ValidationError): + APIResponse(content="", token_usage=usage, cost=0.123) + with self.assertRaises(ValidationError): + APIResponse(content="Test", token_usage=usage, cost=-1) + with self.assertRaises(ValidationError): + APIResponse(content="Test", token_usage=usage, cost=0.123, thinking_time=-1) + with self.assertRaises(ValidationError): + APIResponse(content="Test", token_usage=usage, cost=0.123, provider="") + with self.assertRaises(ValidationError): + APIResponse(content="Test", token_usage=usage, cost=0.123, model="") + def test_openai_cost_calculation(self): - """Test OpenAI cost calculation for supported models""" - # Test o1 model pricing - cost = TokenTracker.calculate_openai_cost(1000000, 500000, "o1") - self.assertEqual(cost, 15.0 + 30.0) # $15/M input + $60/M output - - # Test gpt-4o model pricing - cost = TokenTracker.calculate_openai_cost(1000000, 500000, "gpt-4o") - self.assertEqual(cost, 10.0 + 15.0) # $10/M input + $30/M output - - # Test unsupported model - with self.assertRaises(ValueError): - TokenTracker.calculate_openai_cost(1000000, 500000, "gpt-4") + """Test OpenAI cost calculation""" + # Test o1 model costs + cost = TokenTracker.calculate_openai_cost(1000, 500, "o1") + self.assertAlmostEqual(cost, 0.025) # (1000 * 0.01 + 500 * 0.03) / 1000 + + # Test gpt-4 model costs + cost = TokenTracker.calculate_openai_cost(1000, 500, "gpt-4") + self.assertAlmostEqual(cost, 0.06) # (1000 * 0.03 + 500 * 0.06) / 1000 + + # Test gpt-3.5-turbo model costs + cost = TokenTracker.calculate_openai_cost(1000, 500, "gpt-3.5-turbo") + self.assertAlmostEqual(cost, 0.00125) # (1000 * 0.0005 + 500 * 0.0015) / 1000 def test_claude_cost_calculation(self): """Test Claude cost calculation""" - cost = TokenTracker.calculate_claude_cost(1000000, 500000, "claude-3-sonnet-20240229") - self.assertEqual(cost, 3.0 
+ 7.5) # $3/M input + $15/M output + # Test Claude 3 Opus costs + cost = TokenTracker.calculate_claude_cost(1000, 500, "claude-3-opus-20240229") + self.assertAlmostEqual(cost, 0.0525) # (1000 * 15 + 500 * 75) / 1_000_000 + + # Test Claude 3 Sonnet costs + cost = TokenTracker.calculate_claude_cost(1000, 500, "claude-3-sonnet-20240229") + self.assertAlmostEqual(cost, 0.0105) # (1000 * 3 + 500 * 15) / 1_000_000 + + # Test Claude 3 Haiku costs + cost = TokenTracker.calculate_claude_cost(1000, 500, "claude-3-haiku-20240307") + self.assertAlmostEqual(cost, 0.000875) # (1000 * 0.25 + 500 * 1.25) / 1_000_000 def test_per_day_session_management(self): - """Test per-day session management""" - # Track a request - self.tracker.track_request(self.test_response) + """Test session management with per-day sessions""" + # Create tracker without session ID (should use current date) + tracker = TokenTracker() + tracker.logs_dir = self.test_logs_dir - # Verify file was created - session_file = self.test_logs_dir / f"session_{self.test_session_id}.json" - self.assertTrue(session_file.exists()) + # Track a request + usage = TokenUsage(prompt_tokens=100, completion_tokens=50, total_tokens=150) + response = APIResponse( + content="Test response", + token_usage=usage, + cost=0.123, + thinking_time=1.5, + provider="openai", + model="o1" + ) + tracker.track_request(response) - # Load and verify file contents - with open(session_file) as f: - data = json.load(f) - self.assertEqual(data["session_id"], self.test_session_id) - self.assertEqual(len(data["requests"]), 1) - self.assertEqual(data["requests"][0]["provider"], "openai") - self.assertEqual(data["requests"][0]["model"], "o1") + # Verify request was tracked + self.assertEqual(len(tracker._requests), 1) + self.assertEqual(tracker._requests[0]["provider"], "openai") + self.assertEqual(tracker._requests[0]["model"], "o1") + self.assertEqual(tracker._requests[0]["token_usage"]["total_tokens"], 150) def test_session_file_loading(self): """Test loading existing session file""" @@ -142,73 +187,63 @@ def test_session_file_loading(self): } with open(session_file, "w") as f: json.dump(test_data, f) - + # Create a new tracker - it should load the existing file new_tracker = TokenTracker(self.test_session_id) new_tracker.logs_dir = self.test_logs_dir new_tracker.session_file = self.test_logs_dir / f"session_{self.test_session_id}.json" - self.assertEqual(len(new_tracker.requests), 1) - self.assertEqual(new_tracker.requests[0]["provider"], "openai") - self.assertEqual(new_tracker.requests[0]["model"], "o1") + self.assertEqual(len(new_tracker._requests), 1) + self.assertEqual(new_tracker._requests[0]["provider"], "openai") + self.assertEqual(new_tracker._requests[0]["model"], "o1") + self.assertEqual(new_tracker._requests[0]["token_usage"]["total_tokens"], 150) def test_session_summary_calculation(self): """Test session summary calculation""" - # Add multiple requests with different providers - responses = [ - APIResponse( - content="Test 1", - token_usage=TokenUsage(100, 50, 150, 20), - cost=0.1, - thinking_time=1.0, + tracker = TokenTracker(self.test_session_id) + tracker.logs_dir = self.test_logs_dir + + # Track multiple requests + for i in range(3): + usage = TokenUsage( + prompt_tokens=100 * (i + 1), + completion_tokens=50 * (i + 1), + total_tokens=150 * (i + 1) + ) + response = APIResponse( + content=f"Test response {i}", + token_usage=usage, + cost=0.123 * (i + 1), + thinking_time=1.5 * (i + 1), provider="openai", model="o1" - ), - APIResponse( - content="Test 2", - 
token_usage=TokenUsage(200, 100, 300, None), - cost=0.2, - thinking_time=2.0, - provider="anthropic", - model="claude-3-sonnet-20240229" ) - ] - - for response in responses: - self.tracker.track_request(response) - - summary = self.tracker.get_session_summary() - - # Verify totals - self.assertEqual(summary["total_requests"], 2) - self.assertEqual(summary["total_prompt_tokens"], 300) - self.assertEqual(summary["total_completion_tokens"], 150) - self.assertEqual(summary["total_tokens"], 450) - self.assertAlmostEqual(summary["total_cost"], 0.3, places=6) - self.assertEqual(summary["total_thinking_time"], 3.0) - - # Verify provider stats - self.assertEqual(len(summary["provider_stats"]), 2) - self.assertEqual(summary["provider_stats"]["openai"]["requests"], 1) - self.assertEqual(summary["provider_stats"]["anthropic"]["requests"], 1) + tracker.track_request(response) + + # Get summary + summary = tracker.get_session_summary() + + # Verify summary calculations + self.assertEqual(len(tracker._requests), 3) + self.assertEqual(summary["total_prompt_tokens"], 600) # 100 + 200 + 300 + self.assertEqual(summary["total_completion_tokens"], 300) # 50 + 100 + 150 + self.assertEqual(summary["total_tokens"], 900) # 150 + 300 + 450 + self.assertAlmostEqual(summary["total_cost"], 0.738, places=3) # 0.123 + 0.246 + 0.369 + self.assertAlmostEqual(summary["total_thinking_time"], 9.0, places=1) # 1.5 + 3.0 + 4.5 def test_global_token_tracker(self): """Test global token tracker instance management""" # Get initial tracker with specific session ID tracker1 = get_token_tracker("test-global-1", logs_dir=self.test_logs_dir) self.assertIsNotNone(tracker1) - + # Get another tracker without session ID - should be the same instance tracker2 = get_token_tracker(logs_dir=self.test_logs_dir) self.assertIs(tracker1, tracker2) - + # Get tracker with different session ID - should be new instance tracker3 = get_token_tracker("test-global-2", logs_dir=self.test_logs_dir) self.assertIsNot(tracker1, tracker3) - self.assertEqual(tracker3.session_id, "test-global-2") - - # Get tracker without session ID - should reuse the latest instance - tracker4 = get_token_tracker(logs_dir=self.test_logs_dir) - self.assertIs(tracker3, tracker4) + self.assertEqual(tracker3._session_id, "test-global-2") if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/tests/test_web_scraper.py b/tests/test_web_scraper.py index 09a1e0f..3d9576f 100644 --- a/tests/test_web_scraper.py +++ b/tests/test_web_scraper.py @@ -1,117 +1,267 @@ -import unittest +#!/usr/bin/env python3 + +import pytest from unittest.mock import patch, MagicMock, AsyncMock +import aiohttp import asyncio -import pytest +import json +from pathlib import Path +from typing import List, Dict, Any +from aiohttp import web +from aiohttp.test_utils import TestClient, TestServer +import html5lib + from tools.web_scraper import ( + fetch_page, validate_url, + validate_max_concurrent, parse_html, - fetch_page, - process_urls + process_urls, + FetchError ) +from tools.common.errors import ValidationError +from tools.common.formatting import format_output -class TestWebScraper(unittest.TestCase): - @classmethod - def setUpClass(cls): - """Set up any necessary test fixtures.""" - cls.mock_response = MagicMock() - cls.mock_response.status = 200 - cls.mock_response.text.return_value = "Test content" - - cls.mock_client_session = MagicMock() - cls.mock_client_session.__aenter__.return_value = cls.mock_client_session - cls.mock_client_session.__aexit__.return_value = None - 
cls.mock_client_session.get.return_value.__aenter__.return_value = cls.mock_response +class AsyncContextManagerMock: + def __init__(self, response): + self.response = response + + async def __aenter__(self): + return self.response + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass - def setUp(self): - """Set up test fixtures before each test method.""" - self.urls = ["http://example1.com", "http://example2.com"] - self.mock_session = self.mock_client_session +def test_validate_url(): + """Test URL validation""" + # Test valid URLs + assert validate_url('https://example.com') == True + assert validate_url('http://example.com/path?query=1') == True + assert validate_url('https://sub.example.com:8080/path') == True + + # Test invalid URLs + assert validate_url('not-a-url') == False + assert validate_url('http://') == False + assert validate_url('https://') == False + assert validate_url('') == False + assert validate_url('javascript:alert(1)') == False + assert validate_url('ftp://example.com') == False - def test_validate_url(self): - # Test valid URLs - self.assertTrue(validate_url('https://example.com')) - self.assertTrue(validate_url('http://example.com/path?query=1')) - self.assertTrue(validate_url('https://sub.example.com:8080/path')) - - # Test invalid URLs - self.assertFalse(validate_url('not-a-url')) - self.assertFalse(validate_url('http://')) - self.assertFalse(validate_url('https://')) - self.assertFalse(validate_url('')) +def test_validate_max_concurrent(): + """Test concurrent request validation""" + # Test valid values + assert validate_max_concurrent(1) == 1 + assert validate_max_concurrent(10) == 10 + assert validate_max_concurrent(20) == 20 + + # Test invalid values + with pytest.raises(ValidationError) as exc_info: + validate_max_concurrent(0) + assert "must be a positive integer" in str(exc_info.value) + + with pytest.raises(ValidationError) as exc_info: + validate_max_concurrent(-1) + assert "must be a positive integer" in str(exc_info.value) + + with pytest.raises(ValidationError) as exc_info: + validate_max_concurrent(21) + assert "cannot exceed 20" in str(exc_info.value) - def test_parse_html(self): - # Test with empty or None input - self.assertEqual(parse_html(None), "") - self.assertEqual(parse_html(""), "") - - # Test with simple HTML - html = """ - - -

-                <h1>Title</h1>
-                <p>Paragraph text</p>
- Link text - - - - - """ - result = parse_html(html) - self.assertIn("Title", result) - self.assertIn("Paragraph text", result) - self.assertIn("[Link text](https://example.com)", result) - self.assertNotIn("var x = 1", result) # Script content should be filtered - self.assertNotIn(".css", result) # Style content should be filtered - - # Test with nested elements - html = """ - - +def test_parse_html(): + """Test HTML parsing""" + # Test with empty input + with pytest.raises(ValidationError): + parse_html("") + + with pytest.raises(ValidationError): + parse_html(" ") + + # Test with simple HTML + html = """ + + +

+        <h1>Title</h1>
+        <p>Short text</p>
+        <p>This is a longer paragraph that should be included in the output because it exceeds the minimum length requirement</p>
+        <a href="https://example.com">Link text that is long enough to be included in the output</a>
+        <script>var x = 1;</script>
+        <style>.css { color: red; }</style>
+        <noscript>JavaScript is disabled</noscript>
+        <iframe>Advertisement</iframe>
+    </body>
+    </html>
+    """
+    result = parse_html(html)
+    assert "This is a longer paragraph" in result
+    assert "[Link text that is long enough to be included in the output](https://example.com)" in result
+    assert "Short text" not in result  # Too short
+    assert "var x = 1" not in result  # In script tag
+    assert ".css" not in result  # In style tag
+    assert "JavaScript is disabled" not in result  # In noscript tag
+    assert "Advertisement" not in result  # In iframe tag
+
+    # Test with complex HTML
+    html = """
+    <html>
+    <body>
+
+        <p>First paragraph with enough text to meet the minimum length requirement for inclusion in results</p>
-

Level 1

-
-

Level 2

-
+        <a href="https://example.com/page1">First link with enough text to be included in the output</a>

+        <p>Second paragraph that's also long enough to be included in the parsed output</p>
+ JavaScript link to be ignored + Email link to be ignored + Internal link to be ignored
- - - """ - result = parse_html(html) - self.assertIn("Level 1", result) - self.assertIn("Level 2", result) - - # Test with malformed HTML - html = "

Unclosed paragraph" - result = parse_html(html) - self.assertIn("Unclosed paragraph", result) +

+
+ + +
+ + + """ + result = parse_html(html) + assert "First paragraph with enough text" in result + assert "Second paragraph that's also long enough" in result + assert "[First link with enough text to be included in the output](https://example.com/page1)" in result + assert "JavaScript link" not in result + assert "Email link" not in result + assert "Internal link" not in result + assert "adCode" not in result + assert "Ignored ad content" not in result @pytest.mark.asyncio -class TestWebScraperAsync: - @pytest.fixture - def mock_session(self): - mock_response = AsyncMock() - mock_response.status = 200 - mock_response.text = AsyncMock(return_value="Test content") - - mock_client_session = AsyncMock() - mock_client_session.get = AsyncMock(return_value=mock_response) - mock_client_session.__aenter__ = AsyncMock(return_value=mock_client_session) - mock_client_session.__aexit__ = AsyncMock(return_value=None) - return mock_client_session - - async def test_fetch_page(self, mock_session): - """Test fetching a single page.""" - content = await fetch_page("http://example.com", mock_session) - assert content == "Test content" - mock_session.get.assert_called_once_with("http://example.com") +async def test_fetch_page(): + """Test page fetching""" + # Test invalid URL + with pytest.raises(ValidationError): + await fetch_page("not-a-url") + + # Test successful fetch + mock_response = MagicMock() + mock_response.status = 200 + mock_response.text = AsyncMock(return_value="Test content") + + mock_session = MagicMock() + mock_session.get = MagicMock(return_value=AsyncContextManagerMock(mock_response)) + + content = await fetch_page("http://example.com", session=mock_session) + assert content == "Test content" + + # Test HTTP error + mock_response.status = 404 + mock_session.get = MagicMock(return_value=AsyncContextManagerMock(mock_response)) + + with pytest.raises(FetchError) as exc_info: + await fetch_page("http://example.com", session=mock_session) + assert "HTTP error" in str(exc_info.value) + assert "404" in str(exc_info.value) + + # Test timeout error + mock_session.get = MagicMock(side_effect=asyncio.TimeoutError()) + with pytest.raises(FetchError) as exc_info: + await fetch_page("http://example.com", session=mock_session) + assert "Request timed out" in str(exc_info.value) + + # Test network error + mock_session.get = MagicMock(side_effect=aiohttp.ClientError("Network error")) + with pytest.raises(FetchError) as exc_info: + await fetch_page("http://example.com", session=mock_session) + assert "Network error" in str(exc_info.value) - async def test_process_urls(self, mock_session): - """Test processing multiple URLs concurrently.""" - urls = ["http://example1.com", "http://example2.com"] - results = await process_urls(urls, max_concurrent=2, session=mock_session) - assert len(results) == 2 - assert all(content == "Test content" for content in results) - assert mock_session.get.call_count == 2 +@pytest.mark.asyncio +async def test_process_urls(): + """Test URL processing""" + # Test with invalid URLs + with pytest.raises(ValidationError): + await process_urls(["not-a-url"]) + + # Test with empty URL list + with pytest.raises(ValidationError): + await process_urls([]) + + # Test successful processing + mock_response = MagicMock() + mock_response.status = 200 + mock_response.text = AsyncMock(return_value=""" + + +

+            <p>Test content that is long enough to be included in the results</p>
+ + + """) + + mock_session = MagicMock() + mock_session.get = MagicMock(return_value=AsyncContextManagerMock(mock_response)) + + urls = ["http://example.com", "http://example.org"] + result = await process_urls(urls, session=mock_session) + + assert len(result["results"]) == 2 + assert len(result["errors"]) == 0 + assert "Test content" in result["results"][0]["content"] + + # Test mixed success and failure + def mock_get(url, **kwargs): + if url == "http://example.com": + response = MagicMock() + response.status = 200 + response.text = AsyncMock(return_value=""" + + +

+                    <p>This is a longer piece of content that should definitely meet the minimum length requirement for inclusion in the results. We want to make sure it's processed correctly.</p>
+ + + """) + return AsyncContextManagerMock(response) + else: + response = MagicMock() + response.status = 404 + return AsyncContextManagerMock(response) + + mock_session.get = mock_get + result = await process_urls(urls, session=mock_session) + + assert len(result["results"]) == 1 + assert len(result["errors"]) == 1 + assert "longer piece of content" in result["results"][0]["content"] + assert "http://example.org" in result["errors"] + + # Test max_concurrent limit + urls = [f"http://example.com/{i}" for i in range(5)] + result = await process_urls(urls, max_concurrent=2, session=mock_session) + assert len(result["results"]) + len(result["errors"]) == 5 -if __name__ == '__main__': - unittest.main() +def test_format_output(): + """Test output formatting""" + results = [ + {"url": "http://example.com", "content": "Example text", "timestamp": 123456789}, + {"url": "http://failed.com", "content": "", "timestamp": 123456789} + ] + errors = { + "http://error.com": "Failed to fetch" + } + data = { + "results": results, + "errors": errors + } + + # Test text format + text_output = format_output(data, "text", "Web Scraping Results") + assert "Example text" in text_output + assert "Failed to fetch" in text_output + + # Test JSON format + json_output = format_output(data, "json", "Web Scraping Results") + parsed = json.loads(json_output) + assert len(parsed["data"]["results"]) == 2 + assert parsed["data"]["results"][0]["url"] == "http://example.com" + assert parsed["title"] == "Web Scraping Results" + + # Test markdown format + md_output = format_output(data, "markdown", "Web Scraping Results") + assert "# Web Scraping Results" in md_output + assert "http://example.com" in md_output diff --git a/tools/common/cli.py b/tools/common/cli.py new file mode 100644 index 0000000..3f25e87 --- /dev/null +++ b/tools/common/cli.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 + +import argparse +from typing import Optional, List, Dict, Any +from .logging_config import LogLevel, LogFormat + +def add_common_args(parser: argparse.ArgumentParser) -> None: + """ + Add common arguments to an argument parser. + + Args: + parser: ArgumentParser instance to add arguments to + """ + # Add output format options + parser.add_argument('--format', + choices=['text', 'json', 'markdown'], + default='text', + help='Output format') + + # Add mutually exclusive logging options + log_group = parser.add_mutually_exclusive_group() + log_group.add_argument('--log-level', + choices=['debug', 'info', 'warning', 'error', 'quiet'], + default='info', + help='Set the logging level') + log_group.add_argument('--debug', + action='store_true', + help='Enable debug logging (equivalent to --log-level debug)') + log_group.add_argument('--quiet', + action='store_true', + help='Minimize output (equivalent to --log-level quiet)') + + # Add log format option + parser.add_argument('--log-format', + choices=['text', 'json', 'structured'], + default='text', + help='Log output format') + +def create_parser( + description: str, + *, + add_common: bool = True, + formatter_class: Optional[type] = None +) -> argparse.ArgumentParser: + """ + Create an argument parser with optional common arguments. 
+ + Args: + description: Parser description + add_common: Whether to add common arguments + formatter_class: Optional formatter class + + Returns: + argparse.ArgumentParser: Configured parser + """ + if formatter_class is None: + formatter_class = argparse.ArgumentDefaultsHelpFormatter + + parser = argparse.ArgumentParser( + description=description, + formatter_class=formatter_class + ) + + if add_common: + add_common_args(parser) + + return parser + +def get_log_config(args: argparse.Namespace) -> Dict[str, Any]: + """ + Get logging configuration from parsed arguments. + + Args: + args: Parsed command line arguments + + Returns: + Dict containing log_level and log_format + """ + log_level = LogLevel.DEBUG if args.debug else ( + LogLevel.QUIET if args.quiet else + LogLevel.from_string(args.log_level) + ) + log_format = LogFormat(args.log_format) + + return { + "level": log_level, + "format_type": log_format + } \ No newline at end of file diff --git a/tools/common/errors.py b/tools/common/errors.py new file mode 100644 index 0000000..f48293f --- /dev/null +++ b/tools/common/errors.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 + +from typing import Dict, Any, Optional + +class ToolError(Exception): + """Base class for all tool errors with standardized context""" + + def __init__(self, message: str, context: Optional[Dict[str, Any]] = None): + """ + Initialize error with message and context. + + Args: + message: Error message + context: Optional context dictionary with error details + """ + super().__init__(message) + self.message = message + self.context = context or {} + + def __str__(self) -> str: + """Return formatted error message with context if available""" + if not self.context: + return self.message + + context_str = ", ".join(f"{k}={v}" for k, v in self.context.items()) + return f"{self.message} ({context_str})" + +class ValidationError(ToolError): + """Error for input validation failures""" + pass + +class ConfigError(ToolError): + """Error for configuration and environment issues""" + pass + +class APIError(ToolError): + """Error for external API communication issues""" + def __init__(self, message: str, provider: str, context: Optional[Dict[str, Any]] = None): + """ + Initialize API error. + + Args: + message: Error message + provider: API provider name + context: Optional context dictionary + """ + context = context or {} + context["provider"] = provider + super().__init__(message, context) + self.provider = provider + +class FileError(ToolError): + """Error for file system operations""" + def __init__(self, message: str, path: str, context: Optional[Dict[str, Any]] = None): + """ + Initialize file error. + + Args: + message: Error message + path: File path that caused the error + context: Optional context dictionary + """ + context = context or {} + context["path"] = str(path) + super().__init__(message, context) + self.path = path \ No newline at end of file diff --git a/tools/common/formatting.py b/tools/common/formatting.py new file mode 100644 index 0000000..20f4c26 --- /dev/null +++ b/tools/common/formatting.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 + +import json +from datetime import datetime +from typing import Dict, Any, List, Union +from pathlib import Path +from tabulate import tabulate + +def format_cost(cost: float) -> str: + """ + Format a cost value in dollars. 
+ + Args: + cost: Cost value in dollars + + Returns: + str: Formatted cost string + + Raises: + ValueError: If cost is negative + """ + if cost < 0: + raise ValueError("cost must be non-negative") + return f"${cost:.6f}" + +def format_duration(seconds: float) -> str: + """ + Format duration in a human-readable format. + + Args: + seconds: Duration in seconds + + Returns: + str: Formatted duration string + + Raises: + ValueError: If seconds is negative + """ + if seconds < 0: + raise ValueError("seconds must be non-negative") + + if seconds < 60: + return f"{seconds:.2f}s" + minutes = seconds / 60 + if minutes < 60: + return f"{minutes:.2f}m" + hours = minutes / 60 + return f"{hours:.2f}h" + +def format_file_size(size_bytes: int) -> str: + """ + Format file size in human-readable format. + + Args: + size_bytes: Size in bytes + + Returns: + str: Formatted size string + """ + for unit in ['B', 'KB', 'MB', 'GB']: + if size_bytes < 1024: + return f"{size_bytes:.1f}{unit}" + size_bytes /= 1024 + return f"{size_bytes:.1f}TB" + +def format_timestamp(timestamp: float) -> str: + """ + Format Unix timestamp as human-readable date/time. + + Args: + timestamp: Unix timestamp + + Returns: + str: Formatted timestamp string + """ + return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S') + +def format_output( + data: Union[str, Dict[str, Any], List[Dict[str, Any]]], + format_type: str = 'text', + title: str = None, + metadata: Dict[str, Any] = None +) -> str: + """ + Format data for output in a consistent way. + + Args: + data: Data to format (string, dict, or list of dicts) + format_type: Output format (text, json, or markdown) + title: Optional title for the output + metadata: Optional metadata to include + + Returns: + str: Formatted output string + + Raises: + ValidationError: If format_type is invalid + """ + if format_type not in ['text', 'json', 'markdown']: + raise ValidationError("Invalid output format", { + "format": format_type, + "valid_formats": ['text', 'json', 'markdown'] + }) + + if format_type == 'json': + output = { + "data": data + } + if title: + output["title"] = title + if metadata: + output["metadata"] = metadata + return json.dumps(output, indent=2) + + elif format_type == 'markdown': + output = [] + + if title: + output.extend([f"# {title}\n"]) + + if isinstance(data, str): + output.append(data) + elif isinstance(data, dict): + for key, value in data.items(): + if isinstance(value, dict): + output.extend([f"## {key}", ""]) + for k, v in value.items(): + output.append(f"**{k}**: {v}") + else: + output.append(f"**{key}**: {value}") + elif isinstance(data, list): + for i, item in enumerate(data, 1): + output.extend([f"## Result {i}", ""]) + for key, value in item.items(): + output.append(f"**{key}**: {value}") + output.append("") + + if metadata: + output.extend(["\n---"]) + for key, value in metadata.items(): + output.append(f"*{key}: {value}*") + + return "\n".join(output) + + else: # text format + output = [] + + if title: + output.extend([title, "=" * len(title), ""]) + + if isinstance(data, str): + output = [data] # For string input, just return the string directly + if metadata: # Add metadata for string input too + output.extend(["", "-" * 40]) + for key, value in metadata.items(): + output.append(f"{key}: {value}") + elif isinstance(data, dict): + for key, value in data.items(): + if isinstance(value, dict): + output.extend([f"\n{key}:", ""]) + for k, v in value.items(): + output.append(f"{k}: {v}") + else: + output.append(f"{key}: {value}") + if metadata: + 
output.extend(["", "-" * 40]) + for key, value in metadata.items(): + output.append(f"{key}: {value}") + elif isinstance(data, list): + for i, item in enumerate(data, 1): + output.extend([f"\nResult {i}:", ""]) + for key, value in item.items(): + output.append(f"{key}: {value}") + if metadata: + output.extend(["", "-" * 40]) + for key, value in metadata.items(): + output.append(f"{key}: {value}") + + return "\n".join(output) \ No newline at end of file diff --git a/tools/common/logging_config.py b/tools/common/logging_config.py new file mode 100644 index 0000000..3c4c65d --- /dev/null +++ b/tools/common/logging_config.py @@ -0,0 +1,153 @@ +import logging +import json +import sys +import os +import threading +from pathlib import Path +from typing import Optional, Union, Dict, Any +from enum import Enum +import uuid + +class LogFormat(Enum): + TEXT = "text" + JSON = "json" + STRUCTURED = "structured" # New format for structured text output + +class LogLevel(Enum): + DEBUG = "debug" + INFO = "info" + WARNING = "warning" + ERROR = "error" + QUIET = "quiet" + + @classmethod + def from_string(cls, level: str) -> "LogLevel": + try: + return cls[level.upper()] + except KeyError: + valid_levels = ", ".join(l.value for l in cls) + raise ValueError(f"Invalid log level: {level}. Valid levels are: {valid_levels}") + + def to_logging_level(self) -> int: + return { + LogLevel.DEBUG: logging.DEBUG, + LogLevel.INFO: logging.INFO, + LogLevel.WARNING: logging.WARNING, + LogLevel.ERROR: logging.ERROR, + LogLevel.QUIET: logging.ERROR + 10 # Higher than ERROR + }[self] + +class StructuredFormatter(logging.Formatter): + """Format log records in a structured text format.""" + def format(self, record: logging.LogRecord) -> str: + # Add correlation ID if not present + if not hasattr(record, 'correlation_id'): + record.correlation_id = str(uuid.uuid4()) + + # Build structured message + parts = [ + f"[{self.formatTime(record)}]", + f"[{record.levelname}]", + f"[{record.name}]", + f"[{record.correlation_id}]", + f"[PID:{os.getpid()}]", + f"[TID:{threading.get_ident()}]", + f"[{record.filename}:{record.lineno}]" + ] + + # Add structured context if available + if hasattr(record, 'context'): + context_str = ' '.join(f"{k}={v}" for k, v in record.context.items()) + parts.append(f"[{context_str}]") + + # Add message + parts.append(record.getMessage()) + + # Add exception info if present + if record.exc_info: + parts.append(self.formatException(record.exc_info)) + + return ' '.join(parts) + +class JSONFormatter(logging.Formatter): + """Format log records as JSON.""" + def format(self, record: logging.LogRecord) -> str: + # Add correlation ID if not present + if not hasattr(record, 'correlation_id'): + record.correlation_id = str(uuid.uuid4()) + + log_data = { + "timestamp": self.formatTime(record), + "level": record.levelname, + "logger": record.name, + "correlation_id": record.correlation_id, + "process_id": os.getpid(), + "thread_id": threading.get_ident(), + "file": record.filename, + "line": record.lineno, + "message": record.getMessage() + } + + # Add structured context if available + if hasattr(record, 'context'): + log_data["context"] = record.context + + # Add exception info if present + if record.exc_info: + log_data["exception"] = { + "type": record.exc_info[0].__name__, + "message": str(record.exc_info[1]), + "traceback": self.formatException(record.exc_info) + } + + # Add error details if available + if hasattr(record, 'error_details'): + log_data["error_details"] = record.error_details + + return 
json.dumps(log_data) + +def setup_logging( + name: str, + level: Union[LogLevel, str] = LogLevel.INFO, + format_type: LogFormat = LogFormat.TEXT +) -> logging.Logger: + """ + Set up logging with the specified configuration. + + Args: + name: Logger name + level: Log level to use + format_type: Output format (text, json, or structured) + + Returns: + logging.Logger: Configured logger instance + """ + if isinstance(level, str): + level = LogLevel.from_string(level) + + logger = logging.getLogger(name) + logger.setLevel(level.to_logging_level()) + + # Remove any existing handlers + for handler in logger.handlers[:]: + logger.removeHandler(handler) + + # Create console handler + handler = logging.StreamHandler() + handler.setLevel(level.to_logging_level()) + + # Set formatter based on format type + if format_type == LogFormat.JSON: + formatter = JSONFormatter() + elif format_type == LogFormat.STRUCTURED: + formatter = StructuredFormatter() + else: + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + handler.setFormatter(formatter) + logger.addHandler(handler) + + return logger \ No newline at end of file diff --git a/tools/llm_api.py b/tools/llm_api.py index fee4954..184f5d8 100644 --- a/tools/llm_api.py +++ b/tools/llm_api.py @@ -1,4 +1,4 @@ -#!/usr/bin/env /workspace/tmp_windsurf/venv/bin/python3 +#!/usr/bin/env python3 import google.generativeai as genai from openai import OpenAI, AzureOpenAI @@ -9,148 +9,367 @@ from pathlib import Path import sys import base64 -from typing import Optional, Union, List import mimetypes import time +import json +from typing import Optional, Union, List, Dict, Any from .token_tracker import TokenUsage, APIResponse, get_token_tracker +from .common.logging_config import setup_logging, LogLevel, LogFormat -def load_environment(): - """Load environment variables from .env files in order of precedence""" - # Order of precedence: - # 1. System environment variables (already loaded) - # 2. .env.local (user-specific overrides) - # 3. .env (project defaults) - # 4. .env.example (example configuration) +logger = setup_logging(__name__) + +class LLMApiError(Exception): + """Custom exception for LLM API failures""" + def __init__(self, message: str, provider: str, details: Optional[Dict[str, Any]] = None): + self.provider = provider + self.details = details or {} + super().__init__(message) + +class FileError(Exception): + """Custom exception for file reading failures""" + def __init__(self, message: str, file_path: str, details: Optional[Dict[str, Any]] = None): + self.file_path = file_path + self.details = details or {} + super().__init__(message) + +def load_environment() -> bool: + """ + Load environment variables from .env files in order of precedence. + Returns: + bool: True if any environment file was loaded + + Note: + Order of precedence: + 1. System environment variables (already loaded) + 2. .env.local (user-specific overrides) + 3. .env (project defaults) + 4. 
.env.example (example configuration) + """ env_files = ['.env.local', '.env', '.env.example'] env_loaded = False - print("Current working directory:", Path('.').absolute(), file=sys.stderr) - print("Looking for environment files:", env_files, file=sys.stderr) + logger.debug("Loading environment variables", extra={ + "context": { + "working_directory": str(Path('.').absolute()), + "env_files": env_files + } + }) for env_file in env_files: env_path = Path('.') / env_file - print(f"Checking {env_path.absolute()}", file=sys.stderr) + logger.debug("Checking environment file", extra={ + "context": { + "file": str(env_path.absolute()) + } + }) if env_path.exists(): - print(f"Found {env_file}, loading variables...", file=sys.stderr) + logger.info("Loading environment file", extra={ + "context": { + "file": env_file + } + }) load_dotenv(dotenv_path=env_path) env_loaded = True - print(f"Loaded environment variables from {env_file}", file=sys.stderr) - # Print loaded keys (but not values for security) + + # Log loaded keys (but not values for security) with open(env_path) as f: keys = [line.split('=')[0].strip() for line in f if '=' in line and not line.startswith('#')] - print(f"Keys loaded from {env_file}: {keys}", file=sys.stderr) + logger.debug("Loaded environment variables", extra={ + "context": { + "file": env_file, + "keys": keys + } + }) if not env_loaded: - print("Warning: No .env files found. Using system environment variables only.", file=sys.stderr) - print("Available system environment variables:", list(os.environ.keys()), file=sys.stderr) - -# Load environment variables at module import -load_environment() + logger.warning("No environment files found", extra={ + "context": { + "system_env_keys": list(os.environ.keys()) + } + }) + + return env_loaded def encode_image_file(image_path: str) -> tuple[str, str]: """ Encode an image file to base64 and determine its MIME type. 
Args: - image_path (str): Path to the image file + image_path: Path to the image file Returns: tuple: (base64_encoded_string, mime_type) + + Raises: + FileError: If the image file cannot be read or encoded, or if format is unsupported """ + path = Path(image_path) + if not path.exists(): + logger.error("Image file not found", extra={ + "context": { + "path": str(path) + } + }) + raise FileError("Image file not found", str(path)) + mime_type, _ = mimetypes.guess_type(image_path) if not mime_type: - mime_type = 'image/png' # Default to PNG if type cannot be determined + logger.warning("Could not determine MIME type", extra={ + "context": { + "path": str(path), + "default_mime_type": "image/png" + } + }) + mime_type = 'image/png' + elif mime_type not in ['image/png', 'image/jpeg']: + logger.error("Unsupported image format", extra={ + "context": { + "path": str(path), + "mime_type": mime_type, + "supported_formats": ['image/png', 'image/jpeg'] + } + }) + raise FileError("Unsupported image format", str(path), { + "mime_type": mime_type, + "supported_formats": ['image/png', 'image/jpeg'] + }) + + try: + with open(path, "rb") as image_file: + encoded_string = base64.b64encode(image_file.read()).decode('utf-8') + logger.debug("Successfully encoded image", extra={ + "context": { + "path": str(path), + "encoded_size": len(encoded_string) + } + }) + return encoded_string, mime_type + except Exception as e: + logger.error("Failed to read/encode image file", extra={ + "context": { + "path": str(path), + "error": str(e) + } + }) + raise FileError("Failed to read/encode image file", str(path), { + "error": str(e) + }) + +def create_llm_client(provider: str = "openai") -> Any: + """ + Create an LLM client with proper authentication. + + Args: + provider: The API provider to use - with open(image_path, "rb") as image_file: - encoded_string = base64.b64encode(image_file.read()).decode('utf-8') + Returns: + Any: Configured LLM client - return encoded_string, mime_type + Raises: + LLMApiError: If client creation fails or API key is missing + """ + logger.debug("Creating LLM client", extra={ + "context": { + "provider": provider + } + }) + + try: + if provider == "openai": + api_key = os.getenv('OPENAI_API_KEY') + if not api_key: + logger.error("Missing API key", extra={ + "context": { + "provider": provider, + "required_env": "OPENAI_API_KEY" + } + }) + raise LLMApiError("OPENAI_API_KEY not found in environment variables", provider) + return OpenAI(api_key=api_key) + + elif provider == "azure": + api_key = os.getenv('AZURE_OPENAI_API_KEY') + if not api_key: + logger.error("Missing API key", extra={ + "context": { + "provider": provider, + "required_env": "AZURE_OPENAI_API_KEY" + } + }) + raise LLMApiError("AZURE_OPENAI_API_KEY not found in environment variables", provider) + return AzureOpenAI( + api_key=api_key, + api_version="2024-08-01-preview", + azure_endpoint="https://msopenai.openai.azure.com" + ) + + elif provider == "deepseek": + api_key = os.getenv('DEEPSEEK_API_KEY') + if not api_key: + logger.error("Missing API key", extra={ + "context": { + "provider": provider, + "required_env": "DEEPSEEK_API_KEY" + } + }) + raise LLMApiError("DEEPSEEK_API_KEY not found in environment variables", provider) + return OpenAI( + api_key=api_key, + base_url="https://api.deepseek.com/v1", + ) + + elif provider == "anthropic": + api_key = os.getenv('ANTHROPIC_API_KEY') + if not api_key: + logger.error("Missing API key", extra={ + "context": { + "provider": provider, + "required_env": "ANTHROPIC_API_KEY" + } + }) + 
raise LLMApiError("ANTHROPIC_API_KEY not found in environment variables", provider) + return Anthropic(api_key=api_key) + + elif provider == "gemini": + api_key = os.getenv('GOOGLE_API_KEY') + if not api_key: + logger.error("Missing API key", extra={ + "context": { + "provider": provider, + "required_env": "GOOGLE_API_KEY" + } + }) + raise LLMApiError("GOOGLE_API_KEY not found in environment variables", provider) + genai.configure(api_key=api_key) + return genai + + elif provider == "local": + return OpenAI( + base_url="http://192.168.180.137:8006/v1", + api_key="not-needed" + ) + + else: + logger.error("Unsupported provider", extra={ + "context": { + "provider": provider, + "supported_providers": ["openai", "azure", "deepseek", "anthropic", "gemini", "local"] + } + }) + raise LLMApiError(f"Unsupported provider: {provider}", provider) + + except Exception as e: + if isinstance(e, LLMApiError): + raise + logger.error("Failed to create LLM client", extra={ + "context": { + "provider": provider, + "error": str(e) + } + }) + raise LLMApiError(f"Failed to create {provider} client: {e}", provider) -def create_llm_client(provider="openai"): - if provider == "openai": - api_key = os.getenv('OPENAI_API_KEY') - if not api_key: - raise ValueError("OPENAI_API_KEY not found in environment variables") - return OpenAI( - api_key=api_key - ) - elif provider == "azure": - api_key = os.getenv('AZURE_OPENAI_API_KEY') - if not api_key: - raise ValueError("AZURE_OPENAI_API_KEY not found in environment variables") - return AzureOpenAI( - api_key=api_key, - api_version="2024-08-01-preview", - azure_endpoint="https://msopenai.openai.azure.com" - ) - elif provider == "deepseek": - api_key = os.getenv('DEEPSEEK_API_KEY') - if not api_key: - raise ValueError("DEEPSEEK_API_KEY not found in environment variables") - return OpenAI( - api_key=api_key, - base_url="https://api.deepseek.com/v1", - ) - elif provider == "anthropic": - api_key = os.getenv('ANTHROPIC_API_KEY') - if not api_key: - raise ValueError("ANTHROPIC_API_KEY not found in environment variables") - return Anthropic( - api_key=api_key - ) - elif provider == "gemini": - api_key = os.getenv('GOOGLE_API_KEY') - if not api_key: - raise ValueError("GOOGLE_API_KEY not found in environment variables") - genai.configure(api_key=api_key) - return genai - elif provider == "local": - return OpenAI( - base_url="http://192.168.180.137:8006/v1", - api_key="not-needed" - ) - else: - raise ValueError(f"Unsupported provider: {provider}") +def get_default_model(provider: str) -> str: + """ + Get the default model name for a provider. 
+ + Args: + provider: The API provider + + Returns: + str: Default model name + + Raises: + LLMApiError: If provider is invalid + """ + defaults = { + "openai": "gpt-4o", + "azure": os.getenv('AZURE_OPENAI_MODEL_DEPLOYMENT', 'gpt-4o-ms'), + "deepseek": "deepseek-chat", + "anthropic": "claude-3-sonnet-20240229", + "gemini": "gemini-pro", + "local": "Qwen/Qwen2.5-32B-Instruct-AWQ" + } + + model = defaults.get(provider) + if not model: + logger.error("Invalid provider for default model", extra={ + "context": { + "provider": provider, + "supported_providers": list(defaults.keys()) + } + }) + raise LLMApiError(f"Invalid provider: {provider}", provider) + + return model -def query_llm(prompt: str, client=None, model=None, provider="openai", image_path: Optional[str] = None) -> Optional[str]: +def query_llm( + prompt: str, + client: Optional[Any] = None, + model: Optional[str] = None, + provider: str = "openai", + image_path: Optional[str] = None, + system_content: Optional[str] = None +) -> Optional[str]: """ - Query an LLM with a prompt and optional image attachment. + Query an LLM with a prompt and optional image. Args: - prompt (str): The text prompt to send - client: The LLM client instance - model (str, optional): The model to use - provider (str): The API provider to use - image_path (str, optional): Path to an image file to attach + prompt: The prompt to send to the LLM + client: Optional pre-configured LLM client + model: Optional model name to use + provider: LLM provider to use + image_path: Optional path to an image file to include + system_content: Optional system prompt Returns: - Optional[str]: The LLM's response or None if there was an error + str: LLM response + + Raises: + LLMApiError: If there are issues with the LLM query + FileError: If there are issues with image input or file reading """ - if client is None: - client = create_llm_client(provider) + start_time = time.time() try: - # Set default model - if model is None: - if provider == "openai": - model = "gpt-4o" - elif provider == "azure": - model = os.getenv('AZURE_OPENAI_MODEL_DEPLOYMENT', 'gpt-4o-ms') # Get from env with fallback - elif provider == "deepseek": - model = "deepseek-chat" - elif provider == "anthropic": - model = "claude-3-5-sonnet-20241022" - elif provider == "gemini": - model = "gemini-pro" - elif provider == "local": - model = "Qwen/Qwen2.5-32B-Instruct-AWQ" + if not client: + client = create_llm_client(provider) - start_time = time.time() + if not model: + model = get_default_model(provider) + + # Check for image support + if image_path: + if provider not in ["openai", "gemini"]: + raise FileError("Image input not supported", provider, { + "provider": provider, + "supported_providers": ["openai", "gemini"] + }) + + # Encode image if path provided + base64_image, mime_type = encode_image_file(image_path) + + logger.info("Querying LLM", extra={ + "context": { + "provider": provider, + "model": model, + "prompt_length": len(prompt), + "has_image": image_path is not None, + "has_system_content": system_content is not None + } + }) if provider in ["openai", "local", "deepseek", "azure"]: messages = [{"role": "user", "content": []}] + # Add system content if provided + if system_content: + messages[0]["content"].append({ + "type": "text", + "text": system_content + }) + # Add text content messages[0]["content"].append({ "type": "text", @@ -159,12 +378,10 @@ def query_llm(prompt: str, client=None, model=None, provider="openai", image_pat # Add image content if provided if image_path: - if provider == 
"openai": - encoded_image, mime_type = encode_image_file(image_path) - messages[0]["content"] = [ - {"type": "text", "text": prompt}, - {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{encoded_image}"}} - ] + messages[0]["content"].append({ + "type": "image_url", + "image_url": {"url": f"data:{mime_type};base64,{base64_image}"} + }) kwargs = { "model": model, @@ -207,11 +424,34 @@ def query_llm(prompt: str, client=None, model=None, provider="openai", image_pat ) get_token_tracker().track_request(api_response) + logger.info("LLM response received", extra={ + "context": { + "provider": provider, + "model": model, + "elapsed_seconds": thinking_time, + "response_length": len(response.choices[0].message.content), + "token_usage": { + "prompt": token_usage.prompt_tokens, + "completion": token_usage.completion_tokens, + "total": token_usage.total_tokens, + "reasoning": token_usage.reasoning_tokens + }, + "cost": cost + } + }) + return response.choices[0].message.content elif provider == "anthropic": messages = [{"role": "user", "content": []}] + # Add system content if provided + if system_content: + messages[0]["content"].append({ + "type": "text", + "text": system_content + }) + # Add text content messages[0]["content"].append({ "type": "text", @@ -220,13 +460,12 @@ def query_llm(prompt: str, client=None, model=None, provider="openai", image_pat # Add image content if provided if image_path: - encoded_image, mime_type = encode_image_file(image_path) messages[0]["content"].append({ "type": "image", "source": { "type": "base64", "media_type": mime_type, - "data": encoded_image + "data": base64_image } }) @@ -262,43 +501,245 @@ def query_llm(prompt: str, client=None, model=None, provider="openai", image_pat ) get_token_tracker().track_request(api_response) + logger.info("LLM response received", extra={ + "context": { + "provider": provider, + "model": model, + "elapsed_seconds": thinking_time, + "response_length": len(response.content[0].text), + "token_usage": { + "prompt": token_usage.prompt_tokens, + "completion": token_usage.completion_tokens, + "total": token_usage.total_tokens + }, + "cost": cost + } + }) + return response.content[0].text elif provider == "gemini": model = client.GenerativeModel(model) response = model.generate_content(prompt) + thinking_time = time.time() - start_time + + logger.info("LLM response received", extra={ + "context": { + "provider": provider, + "model": model, + "elapsed_seconds": thinking_time, + "response_length": len(response.text) + } + }) + return response.text except Exception as e: - print(f"Error querying LLM: {e}", file=sys.stderr) - return None + if isinstance(e, (LLMApiError, FileError)): + raise + logger.error("LLM query failed", extra={ + "context": { + "provider": provider, + "model": model, + "error": str(e), + "elapsed_seconds": time.time() - start_time + } + }) + raise LLMApiError(f"Failed to query {provider} LLM: {e}", provider, { + "model": model, + "error": str(e), + "elapsed_seconds": time.time() - start_time + }) + +def read_content_or_file(value: str) -> str: + """ + Read content from a string or file if prefixed with @. 
+ + Args: + value: String value or @file path + + Returns: + str: Content from string or file + + Raises: + FileError: If file cannot be read + """ + if not value: + return "" + + if value.startswith('@'): + file_path = value[1:] # Remove @ prefix + try: + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + logger.debug("Read file content", extra={ + "context": { + "path": file_path, + "content_length": len(content) + } + }) + return content + except Exception as e: + logger.error("Failed to read file", extra={ + "context": { + "path": file_path, + "error": str(e) + } + }) + raise FileError("Failed to read file", file_path, { + "error": str(e) + }) + return value def main(): - parser = argparse.ArgumentParser(description='Query an LLM with a prompt') - parser.add_argument('--prompt', type=str, help='The prompt to send to the LLM', required=True) - parser.add_argument('--provider', choices=['openai','anthropic','gemini','local','deepseek','azure'], default='openai', help='The API provider to use') + parser = argparse.ArgumentParser( + description='Query an LLM with a prompt', + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument('--prompt', type=str, help='The prompt to send to the LLM (prefix with @ to read from file)', required=True) + parser.add_argument('--system', type=str, help='System prompt to use (prefix with @ to read from file)') + parser.add_argument('--file', type=str, action='append', help='Path to a file to include in the prompt. Can be specified multiple times.') + parser.add_argument('--provider', + choices=['openai', 'anthropic', 'gemini', 'local', 'deepseek', 'azure'], + default='openai', + help='The API provider to use') parser.add_argument('--model', type=str, help='The model to use (default depends on provider)') parser.add_argument('--image', type=str, help='Path to an image file to attach to the prompt') + + # Add output format options + parser.add_argument('--format', + choices=['text', 'json', 'markdown'], + default='text', + help='Output format') + + # Add mutually exclusive logging options + log_group = parser.add_mutually_exclusive_group() + log_group.add_argument('--log-level', + choices=['debug', 'info', 'warning', 'error', 'quiet'], + default='info', + help='Set the logging level') + log_group.add_argument('--debug', + action='store_true', + help='Enable debug logging (equivalent to --log-level debug)') + log_group.add_argument('--quiet', + action='store_true', + help='Minimize output (equivalent to --log-level quiet)') + + # Add log format option + parser.add_argument('--log-format', + choices=['text', 'json', 'structured'], + default='text', + help='Log output format') + args = parser.parse_args() - if not args.model: - if args.provider == 'openai': - args.model = "gpt-4o" - elif args.provider == "deepseek": - args.model = "deepseek-chat" - elif args.provider == 'anthropic': - args.model = "claude-3-5-sonnet-20241022" - elif args.provider == 'gemini': - args.model = "gemini-2.0-flash-exp" - elif args.provider == 'azure': - args.model = os.getenv('AZURE_OPENAI_MODEL_DEPLOYMENT', 'gpt-4o-ms') # Get from env with fallback + # Configure logging + log_level = LogLevel.DEBUG if args.debug else ( + LogLevel.QUIET if args.quiet else + LogLevel.from_string(args.log_level) + ) + log_format = LogFormat(args.log_format) + logger = setup_logging(__name__, level=log_level, format_type=log_format) + logger.debug("Debug logging enabled", extra={ + "context": { + "log_level": log_level.value, + "log_format": log_format.value 
+ } + }) + + try: + # Load environment variables + load_environment() + + # Get default model if not specified + if not args.model: + args.model = get_default_model(args.provider) + logger.debug("Using default model", extra={ + "context": { + "provider": args.provider, + "model": args.model + } + }) + + # Read prompt and system content + prompt = read_content_or_file(args.prompt) + system_content = read_content_or_file(args.system) if args.system else "" + + # Read file contents if specified + file_contents = [] + if args.file: + for file_path in args.file: + try: + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + file_contents.append(f"\nFile: {file_path}\n{content}") + logger.debug("Read included file", extra={ + "context": { + "path": file_path, + "content_length": len(content) + } + }) + except Exception as e: + logger.error("Failed to read included file", extra={ + "context": { + "path": file_path, + "error": str(e) + } + }) + raise FileError(f"Failed to read file {file_path}: {e}", file_path, { + "error": str(e) + }) + + # Combine prompt with file contents + if file_contents: + prompt = prompt + "\n\nIncluded files:" + "\n---".join(file_contents) - client = create_llm_client(args.provider) - response = query_llm(args.prompt, client, model=args.model, provider=args.provider, image_path=args.image) - if response: - print(response) - else: - print("Failed to get response from LLM") + # Create client and query LLM + client = create_llm_client(args.provider) + response = query_llm( + prompt, + client, + args.model, + args.provider, + args.image, + system_content + ) + + if response: + if args.format == 'json': + print(json.dumps({ + "response": response, + "model": args.model, + "provider": args.provider + }, indent=2)) + elif args.format == 'markdown': + print(f"# LLM Response\n\n{response}\n\n---\n*Model: {args.model} ({args.provider})*") + else: + print(response) + else: + logger.error("No response received", extra={ + "context": { + "provider": args.provider, + "model": args.model + } + }) + sys.exit(1) + + except FileError as e: + logger.error(str(e), extra={ + "context": { + "file_path": e.file_path, + "details": e.details + } + }) + sys.exit(1) + except Exception as e: + logger.error("Unexpected error", extra={ + "context": { + "error": str(e) + } + }) + sys.exit(1) if __name__ == "__main__": main() \ No newline at end of file diff --git a/tools/plan_exec_llm.py b/tools/plan_exec_llm.py index 9861924..3c76b14 100644 --- a/tools/plan_exec_llm.py +++ b/tools/plan_exec_llm.py @@ -7,82 +7,369 @@ from dotenv import load_dotenv import sys import time +import json +import aiohttp +from typing import Optional, Dict, Any, Union, List from .token_tracker import TokenUsage, APIResponse, get_token_tracker +from .common.logging_config import setup_logging, LogLevel, LogFormat +from .common.errors import ToolError, ValidationError, APIError, FileError +from .common.formatting import format_output, format_duration, format_cost +from .common.cli import create_parser, get_log_config + +logger = setup_logging(__name__) STATUS_FILE = '.cursorrules' -def load_environment(): - """Load environment variables from .env files""" +def validate_file_path(path: Union[str, Path], must_exist: bool = True) -> Path: + """ + Validate a file path. 
+ + Args: + path: Path to validate + must_exist: Whether the file must exist + + Returns: + Path: Validated Path object + + Raises: + FileError: If path is invalid or file doesn't exist when required + """ + try: + path_obj = Path(path) + if must_exist and not path_obj.exists(): + logger.error("File not found", extra={ + "context": { + "path": str(path_obj), + "must_exist": must_exist + } + }) + raise FileError("File not found", str(path_obj)) + if path_obj.exists() and not path_obj.is_file(): + logger.error("Path exists but is not a file", extra={ + "context": { + "path": str(path_obj) + } + }) + raise FileError("Not a file", str(path_obj)) + return path_obj + except Exception as e: + if isinstance(e, FileError): + raise + logger.error("Invalid file path", extra={ + "context": { + "path": str(path), + "error": str(e) + } + }) + raise FileError("Invalid file path", str(path), { + "error": str(e) + }) + +def load_environment() -> bool: + """ + Load environment variables from .env files in order of precedence. + + Returns: + bool: True if any environment file was loaded + + Note: + Order of precedence: + 1. System environment variables (already loaded) + 2. .env.local (user-specific overrides) + 3. .env (project defaults) + 4. .env.example (example configuration) + """ env_files = ['.env.local', '.env', '.env.example'] env_loaded = False + logger.debug("Loading environment variables", extra={ + "context": { + "working_directory": str(Path('.').absolute()), + "env_files": env_files + } + }) + for env_file in env_files: env_path = Path('.') / env_file + logger.debug("Checking environment file", extra={ + "context": { + "file": str(env_path.absolute()) + } + }) if env_path.exists(): + logger.info("Loading environment file", extra={ + "context": { + "file": env_file + } + }) load_dotenv(dotenv_path=env_path) env_loaded = True - break + + # Log loaded keys (but not values for security) + with open(env_path) as f: + keys = [line.split('=')[0].strip() for line in f if '=' in line and not line.startswith('#')] + logger.debug("Loaded environment variables", extra={ + "context": { + "file": env_file, + "keys": keys + } + }) if not env_loaded: - print("Warning: No .env files found. Using system environment variables only.", file=sys.stderr) + logger.warning("No environment files found", extra={ + "context": { + "system_env_keys": list(os.environ.keys()) + } + }) + + return env_loaded -def read_plan_status(): - """Read the content of the plan status file, only including content after Multi-Agent Scratchpad""" - status_file = STATUS_FILE +def read_plan_status() -> str: + """ + Read the content of the plan status file, only including content after Multi-Agent Scratchpad. 
+ + Returns: + str: Content of the Multi-Agent Scratchpad section + + Raises: + FileError: If there are issues reading the file + ValidationError: If section not found + """ try: + status_file = validate_file_path(STATUS_FILE) + logger.debug("Reading status file", extra={ + "context": { + "file": str(status_file) + } + }) + with open(status_file, 'r', encoding='utf-8') as f: content = f.read() + # Find the Multi-Agent Scratchpad section scratchpad_marker = "# Multi-Agent Scratchpad" - if scratchpad_marker in content: - return content[content.index(scratchpad_marker):] - else: - print(f"Warning: '{scratchpad_marker}' section not found in {status_file}", file=sys.stderr) - return "" + if scratchpad_marker not in content: + logger.error("Multi-Agent Scratchpad section not found", extra={ + "context": { + "file": str(status_file), + "marker": scratchpad_marker + } + }) + raise ValidationError( + f"'{scratchpad_marker}' section not found in status file", + { + "file": str(status_file), + "marker": scratchpad_marker + } + ) + + section_content = content[content.index(scratchpad_marker) + len(scratchpad_marker):] + if not section_content.strip(): + logger.error("Empty Multi-Agent Scratchpad section", extra={ + "context": { + "file": str(status_file) + } + }) + raise ValidationError( + "Multi-Agent Scratchpad section is empty", + { + "file": str(status_file) + } + ) + + logger.debug("Found Multi-Agent Scratchpad section", extra={ + "context": { + "file": str(status_file), + "content_length": len(section_content) + } + }) + return section_content + except Exception as e: - print(f"Error reading {status_file}: {e}", file=sys.stderr) - return "" + if isinstance(e, (FileError, ValidationError)): + raise + logger.error("Failed to read status file", extra={ + "context": { + "file": STATUS_FILE, + "error": str(e) + } + }) + raise FileError("Failed to read status file", STATUS_FILE, { + "error": str(e) + }) -def read_file_content(file_path): - """Read content from a specified file""" +def read_file_content(file_path: str) -> str: + """ + Read content from a specified file. + + Args: + file_path: Path to the file to read + + Returns: + str: File content + + Raises: + FileError: If there are issues reading the file + """ try: - with open(file_path, 'r', encoding='utf-8') as f: - return f.read() + path = validate_file_path(file_path) + logger.debug("Reading file", extra={ + "context": { + "path": str(path) + } + }) + + with open(path, 'r', encoding='utf-8') as f: + content = f.read() + if not content.strip(): + logger.warning("File is empty", extra={ + "context": { + "path": str(path) + } + }) + else: + logger.debug("Successfully read file", extra={ + "context": { + "path": str(path), + "content_length": len(content) + } + }) + return content + except Exception as e: - print(f"Error reading {file_path}: {e}", file=sys.stderr) - return None + if isinstance(e, FileError): + raise + logger.error("Failed to read file", extra={ + "context": { + "path": file_path, + "error": str(e) + } + }) + raise FileError("Failed to read file", file_path, { + "error": str(e) + }) -def create_llm_client(): - """Create OpenAI client""" +def create_llm_client() -> OpenAI: + """ + Create OpenAI client with proper authentication. 
+ + Returns: + OpenAI: Configured OpenAI client + + Raises: + PlanExecError: If API key is missing or invalid + """ api_key = os.getenv('OPENAI_API_KEY') if not api_key: - raise ValueError("OPENAI_API_KEY not found in environment variables") - return OpenAI(api_key=api_key) + logger.error("Missing API key", extra={ + "context": { + "required_env": "OPENAI_API_KEY" + } + }) + raise FileError( + "OPENAI_API_KEY not found in environment variables", + "client_creation" + ) + + try: + logger.debug("Creating OpenAI client") + return OpenAI(api_key=api_key) + except Exception as e: + logger.error("Failed to create OpenAI client", extra={ + "context": { + "error": str(e) + } + }) + raise FileError( + f"Failed to create OpenAI client: {e}", + "client_creation" + ) -def query_llm(plan_content, user_prompt=None, file_content=None): - """Query the LLM with combined prompts""" - client = create_llm_client() +def read_content_or_file(value: str) -> str: + """ + Read content from a string or file if prefixed with @. - # Combine prompts - system_prompt = """""" + Args: + value: String value or @file path + + Returns: + str: Content from string or file + + Raises: + FileError: If file cannot be read + """ + if not value: + return "" - combined_prompt = f"""You are working on a multi-agent context. The executor is the one who actually does the work. And you are the planner. Now the executor is asking you for help. Please analyze the provided project plan and status, then address the executor's specific query or request. + if value.startswith('@'): + file_path = value[1:] # Remove @ prefix + try: + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + logger.debug("Read file content", extra={ + "context": { + "path": file_path, + "content_length": len(content) + } + }) + return content + except Exception as e: + raise FileError("Failed to read file", file_path, { + "error": str(e) + }) + return value -You need to think like a founder. Prioritize agility and don't over-engineer. Think deep. Try to foresee challenges and derisk earlier. If opportunity sizing or probing experiments can reduce risk with low cost, instruct the executor to do them. +def query_llm(plan_content: str, user_prompt: Optional[str] = None, file_content: Optional[str] = None, system_prompt: Optional[str] = None) -> str: + """ + Query the LLM with combined prompts. -Project Plan and Status: + Args: + plan_content: Current plan and status content + user_prompt: Optional additional user prompt + file_content: Optional content from a specific file + system_prompt: Optional system prompt to override default + + Returns: + str: LLM response + + Raises: + ValidationError: If inputs are invalid + APIError: If there are issues with the LLM query + """ + if not plan_content or not plan_content.strip(): + logger.error("Empty plan content", extra={ + "context": { + "has_user_prompt": user_prompt is not None, + "has_file_content": file_content is not None + } + }) + raise ValidationError("Plan content cannot be empty", { + "has_user_prompt": user_prompt is not None, + "has_file_content": file_content is not None + }) + + try: + client = create_llm_client() + + # Initialize system prompt + system_prompt = system_prompt or """You are working on a multi-agent context. The executor is the one who actually does the work. And you are the planner. Now the executor is asking you for help. Please analyze the provided project plan and status, then address the executor's specific query or request. + +You need to think like a founder. 
Prioritize agility and don't over-engineer. Think deep. Try to foresee challenges and derisk earlier. If opportunity sizing or probing experiments can reduce risk with low cost, instruct the executor to do them.""" + + # Combine prompts with exact text preserved + combined_prompt = f"""Project Plan and Status: ====== {plan_content} ====== """ - - if file_content: - combined_prompt += f"\nFile Content:\n======\n{file_content}\n======\n" - - if user_prompt: - combined_prompt += f"\nUser Query:\n{user_prompt}\n" - - combined_prompt += """\nYour response should be focusing on revising the Multi-Agent Scratchpad section in the .cursorrules file. There is no need to regenerate the entire document. You can use the following format to prompt how to revise the document: + + if file_content: + combined_prompt += f"\nFile Content:\n======\n{file_content}\n======\n" + + if user_prompt: + combined_prompt += f"\nUser Query:\n{user_prompt}\n" + + combined_prompt += """\nYour response should be focusing on revising the Multi-Agent Scratchpad section in the .cursorrules file. There is no need to regenerate the entire document. You can use the following format to prompt how to revise the document: <<<<<< @@ -90,10 +377,17 @@ def query_llm(plan_content, user_prompt=None, file_content=None): >>>>>>> -We will do the actual changes in the .cursorrules file. -""" - - try: +We will do the actual changes in the .cursorrules file.""" + + logger.info("Querying LLM", extra={ + "context": { + "prompt_length": len(combined_prompt), + "has_user_prompt": user_prompt is not None, + "has_file_content": file_content is not None, + "has_system_prompt": system_prompt is not None + } + }) + start_time = time.time() response = client.chat.completions.create( model="o1", @@ -132,40 +426,377 @@ def query_llm(plan_content, user_prompt=None, file_content=None): ) get_token_tracker().track_request(api_response) + logger.info("LLM response received", extra={ + "context": { + "elapsed_seconds": thinking_time, + "response_length": len(response.choices[0].message.content), + "token_usage": { + "prompt": token_usage.prompt_tokens, + "completion": token_usage.completion_tokens, + "total": token_usage.total_tokens, + "reasoning": token_usage.reasoning_tokens + }, + "cost": cost + } + }) + return response.choices[0].message.content + except Exception as e: - print(f"Error querying LLM: {e}", file=sys.stderr) - return None + if isinstance(e, (ValidationError, APIError)): + raise + logger.error("Failed to query LLM", extra={ + "context": { + "error": str(e), + "elapsed_seconds": time.time() - start_time + } + }) + raise APIError("Failed to query LLM", "openai", { + "error": str(e), + "elapsed_seconds": time.time() - start_time + }) -def main(): - parser = argparse.ArgumentParser(description='Query OpenAI o1 model with project plan context') - parser.add_argument('--prompt', type=str, help='Additional prompt to send to the LLM', required=False) - parser.add_argument('--file', type=str, help='Path to a file whose content should be included in the prompt', required=False) - args = parser.parse_args() +def format_output(response: str, format_type: str = 'text') -> str: + """ + Format the LLM response for display. 
+ + Args: + response: Raw LLM response + format_type: Output format (text, json, or markdown) + + Returns: + str: Formatted output string + + Raises: + PlanExecError: If response is invalid + ValidationError: If format_type is invalid + """ + if not response or not response.strip(): + logger.error("Empty LLM response", extra={ + "context": { + "format_type": format_type + } + }) + raise FileError("Empty response from LLM", "output_formatting") + + if format_type == 'json': + return json.dumps({ + "response": response, + "model": "o1", + "provider": "openai", + "context": "plan_exec" + }, indent=2) + elif format_type == 'markdown': + return f"""# Plan Execution Response - # Load environment variables - load_environment() +{response} - # Read plan status - plan_content = read_plan_status() +--- +*Model: o1 (OpenAI)* +*Context: Plan Execution*""" + elif format_type == 'text': + # Add clear section markers for text format + sections = [ + 'Following is the instruction on how to revise the Multi-Agent Scratchpad section in .cursorrules:', + '=' * 72, + response, + '=' * 72, + 'End of instruction' + ] + return '\n'.join(sections) + else: + logger.error("Invalid format type", extra={ + "context": { + "format_type": format_type, + "valid_formats": ["text", "json", "markdown"] + } + }) + raise ValidationError("Invalid format type", { + "format_type": format_type, + "valid_formats": ["text", "json", "markdown"] + }) - # Read file content if specified - file_content = None - if args.file: - file_content = read_file_content(args.file) - if file_content is None: - sys.exit(1) +def validate_plan(plan: Dict[str, Any]) -> None: + """ + Validate plan structure. + + Args: + plan: Plan dictionary to validate + + Raises: + ValidationError: If plan structure is invalid + """ + required_keys = ["goal", "steps"] + missing_keys = [k for k in required_keys if k not in plan] + if missing_keys: + raise ValidationError("Missing required plan keys", { + "missing_keys": missing_keys, + "plan": plan + }) + + if not isinstance(plan["steps"], list): + raise ValidationError("Steps must be a list", { + "steps_type": type(plan["steps"]).__name__ + }) + + for i, step in enumerate(plan["steps"]): + if not isinstance(step, dict): + raise ValidationError(f"Step {i} must be a dictionary", { + "step_type": type(step).__name__, + "step_index": i + }) + + step_keys = ["description", "action", "expected_result"] + missing_step_keys = [k for k in step_keys if k not in step] + if missing_step_keys: + raise ValidationError(f"Step {i} missing required keys", { + "missing_keys": missing_step_keys, + "step_index": i, + "step": step + }) - # Query LLM and output response - response = query_llm(plan_content, args.prompt, file_content) - if response: - print('Following is the instruction on how to revise the Multi-Agent Scratchpad section in .cursorrules:') - print('========================================================') - print(response) - print('========================================================') - print('Now please do the actual changes in the .cursorrules file. And then switch to the executor role, and read the content of the file to decide what to do next.') - else: - print("Failed to get response from LLM") +def validate_execution_result(result: Dict[str, Any]) -> None: + """ + Validate execution result structure. 
+ + Args: + result: Result dictionary to validate + + Raises: + ValidationError: If result structure is invalid + """ + required_keys = ["success", "output"] + missing_keys = [k for k in required_keys if k not in result] + if missing_keys: + raise ValidationError("Missing required result keys", { + "missing_keys": missing_keys, + "result": result + }) + + if not isinstance(result["success"], bool): + raise ValidationError("Success must be a boolean", { + "success_type": type(result["success"]).__name__ + }) + +async def execute_plan( + plan: Dict[str, Any], + session: Optional[aiohttp.ClientSession] = None, + timeout: int = 300 +) -> Dict[str, Any]: + """ + Execute a plan using LLM for validation. + + Args: + plan: Plan dictionary with goal and steps + session: Optional aiohttp session to reuse + timeout: Request timeout in seconds + + Returns: + Dict containing execution results + + Raises: + ValidationError: If plan structure is invalid + PlanExecError: If execution fails + """ + # Validate plan structure + validate_plan(plan) + + logger.info("Starting plan execution", extra={ + "context": { + "goal": plan["goal"], + "num_steps": len(plan["steps"]), + "timeout": timeout + } + }) + + start_time = time.time() + results = [] + total_cost = 0.0 + + try: + for i, step in enumerate(plan["steps"]): + step_start = time.time() + logger.info(f"Executing step {i+1}", extra={ + "context": { + "step_index": i, + "description": step["description"], + "action": step["action"] + } + }) + + # Execute step action + try: + # TODO: Implement actual step execution + # This is a placeholder that always succeeds + step_result = { + "success": True, + "output": f"Executed: {step['action']}" + } + + # Validate result structure + validate_execution_result(step_result) + + step_elapsed = time.time() - step_start + results.append({ + "step_index": i, + "description": step["description"], + "action": step["action"], + "expected_result": step["expected_result"], + "success": step_result["success"], + "output": step_result["output"], + "elapsed_time": step_elapsed + }) + + logger.info(f"Step {i+1} completed", extra={ + "context": { + "step_index": i, + "success": step_result["success"], + "elapsed_seconds": step_elapsed + } + }) + + except Exception as e: + step_elapsed = time.time() - step_start + logger.error(f"Step {i+1} failed", extra={ + "context": { + "step_index": i, + "error": str(e), + "elapsed_seconds": step_elapsed + } + }) + + results.append({ + "step_index": i, + "description": step["description"], + "action": step["action"], + "expected_result": step["expected_result"], + "success": False, + "error": str(e), + "elapsed_time": step_elapsed + }) + + raise FileError( + f"Step {i+1} failed: {e}", + "execution", + { + "step_index": i, + "description": step["description"], + "error": str(e) + } + ) + + except Exception as e: + if not isinstance(e, (ValidationError, FileError)): + logger.error("Unexpected error during execution", extra={ + "context": { + "error": str(e), + "elapsed_seconds": time.time() - start_time + } + }) + raise FileError( + "Execution failed", + "unknown", + { + "error": str(e), + "elapsed_seconds": time.time() - start_time + } + ) + raise + + elapsed = time.time() - start_time + + result = { + "goal": plan["goal"], + "steps": results, + "total_steps": len(results), + "successful_steps": sum(1 for r in results if r["success"]), + "failed_steps": sum(1 for r in results if not r["success"]), + "elapsed_time": elapsed + } + + logger.info("Plan execution completed", extra={ + "context": { + 
"goal": plan["goal"], + "total_steps": result["total_steps"], + "successful_steps": result["successful_steps"], + "failed_steps": result["failed_steps"], + "elapsed_seconds": elapsed + } + }) + + return result + +def main(): + parser = create_parser('Execute a plan using LLM for validation') + parser.add_argument('--plan', type=str, required=True, + help='JSON string or @file containing the plan') + parser.add_argument('--timeout', type=int, default=300, + help='Request timeout in seconds') + + args = parser.parse_args() + + # Configure logging + log_config = get_log_config(args) + logger = setup_logging(__name__, **log_config) + logger.debug("Debug logging enabled", extra={ + "context": { + "log_level": log_config["level"].value, + "log_format": log_config["format_type"].value + } + }) + + try: + # Load plan + if args.plan.startswith('@'): + try: + with open(args.plan[1:], 'r') as f: + plan = json.load(f) + except Exception as e: + raise ValidationError("Failed to load plan file", { + "file": args.plan[1:], + "error": str(e) + }) + else: + try: + plan = json.loads(args.plan) + except json.JSONDecodeError as e: + raise ValidationError("Invalid plan JSON", { + "error": str(e), + "plan": args.plan + }) + + start_time = time.time() + result = asyncio.run(execute_plan( + plan, + timeout=args.timeout + )) + elapsed = time.time() - start_time + + metadata = { + "total_steps": result["total_steps"], + "successful_steps": result["successful_steps"], + "failed_steps": result["failed_steps"], + "elapsed_time": format_duration(elapsed) + } + + print(format_output(result, args.format, "Plan Execution Results", metadata)) + + # Exit with error if any steps failed + if result["failed_steps"] > 0: + sys.exit(1) + + except ValidationError as e: + logger.error("Invalid input", extra={"context": e.context}) + sys.exit(1) + except FileError as e: + logger.error("Execution error", extra={"context": e.context}) + sys.exit(1) + except Exception as e: + logger.error("Processing failed", extra={ + "context": { + "error": str(e) + } + }) sys.exit(1) if __name__ == "__main__": diff --git a/tools/screenshot_utils.py b/tools/screenshot_utils.py index 40d0098..15d93bc 100755 --- a/tools/screenshot_utils.py +++ b/tools/screenshot_utils.py @@ -1,56 +1,282 @@ #!/usr/bin/env python3 import asyncio -from playwright.async_api import async_playwright +from playwright.async_api import async_playwright, Error as PlaywrightError import os import tempfile from pathlib import Path +import sys +from typing import Optional, Dict, Any +from urllib.parse import urlparse +import time +from .common.logging_config import setup_logging +from .common.errors import ToolError, ValidationError, FileError +from .common.formatting import format_output, format_file_size, format_duration +from .common.cli import create_parser, get_log_config + +logger = setup_logging(__name__) + +class ScreenshotError(ToolError): + """Custom exception for screenshot failures""" + def __init__(self, message: str, url: str, details: Optional[Dict[str, Any]] = None): + context = details or {} + context["url"] = url + super().__init__(message, context) + self.url = url + +def validate_url(url: str) -> bool: + """Validate URL format""" + try: + result = urlparse(url) + return all([result.scheme in ('http', 'https'), result.netloc]) + except Exception: + return False + +def validate_dimensions(width: int, height: int) -> tuple[int, int]: + """Validate viewport dimensions""" + if width < 1 or height < 1: + raise ValidationError("Width and height must be positive 
integers", { + "width": width, + "height": height + }) + if width > 16384 or height > 16384: + raise ValidationError("Width and height cannot exceed 16384 pixels", { + "width": width, + "height": height + }) + return width, height async def take_screenshot(url: str, output_path: str = None, width: int = 1280, height: int = 720) -> str: """ Take a screenshot of a webpage using Playwright. Args: - url (str): The URL to take a screenshot of - output_path (str, optional): Path to save the screenshot. If None, saves to a temporary file. - width (int, optional): Viewport width. Defaults to 1280. - height (int, optional): Viewport height. Defaults to 720. + url: The URL to take a screenshot of + output_path: Path to save the screenshot. If None, saves to a temporary file. + width: Viewport width. Defaults to 1280. + height: Viewport height. Defaults to 720. Returns: str: Path to the saved screenshot + + Raises: + ValidationError: If input parameters are invalid + ScreenshotError: If screenshot capture fails + FileError: If output file cannot be created or written """ + # Validate inputs + if not validate_url(url): + logger.error("Invalid URL format", extra={ + "context": { + "url": url, + "error": "Invalid URL format" + } + }) + raise ValidationError("Invalid URL format", {"url": url}) + + try: + width, height = validate_dimensions(width, height) + except ValidationError as e: + logger.error("Invalid dimensions", extra={ + "context": e.context + }) + raise + + # Prepare output path if output_path is None: - # Create a temporary file with .png extension - temp_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False) - output_path = temp_file.name - temp_file.close() - - async with async_playwright() as p: - browser = await p.chromium.launch(headless=True) - page = await browser.new_page(viewport={'width': width, 'height': height}) - try: - await page.goto(url, wait_until='networkidle') - await page.screenshot(path=output_path, full_page=True) - finally: - await browser.close() + temp_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False) + output_path = temp_file.name + temp_file.close() + logger.debug("Created temporary file", extra={ + "context": { + "path": output_path + } + }) + except OSError as e: + logger.error("Failed to create temporary file", extra={ + "context": { + "error": str(e) + } + }) + raise FileError("Failed to create temporary file", "temp_file", {"error": str(e)}) + else: + # Ensure directory exists + output_dir = os.path.dirname(output_path) + if output_dir: + try: + os.makedirs(output_dir, exist_ok=True) + logger.debug("Created output directory", extra={ + "context": { + "directory": output_dir + } + }) + except OSError as e: + logger.error("Failed to create output directory", extra={ + "context": { + "directory": output_dir, + "error": str(e) + } + }) + raise FileError("Failed to create output directory", output_dir, {"error": str(e)}) - return output_path + logger.info("Taking screenshot", extra={ + "context": { + "url": url, + "output_path": output_path, + "dimensions": f"{width}x{height}" + } + }) + start_time = time.time() + + try: + async with async_playwright() as p: + browser = await p.chromium.launch(headless=True) + try: + page = await browser.new_page(viewport={'width': width, 'height': height}) + + logger.debug("Navigating to page", extra={ + "context": { + "url": url + } + }) + await page.goto(url, wait_until='networkidle') + + logger.debug("Capturing screenshot", extra={ + "context": { + "output_path": output_path + } + }) + await 
page.screenshot(path=output_path, full_page=True) + + elapsed = time.time() - start_time + file_size = os.path.getsize(output_path) + logger.info("Screenshot captured successfully", extra={ + "context": { + "url": url, + "output_path": output_path, + "elapsed_seconds": elapsed, + "file_size_bytes": file_size + } + }) + return output_path + + except PlaywrightError as e: + logger.error("Playwright error", extra={ + "context": { + "url": url, + "error": str(e), + "elapsed_seconds": time.time() - start_time + } + }) + raise ScreenshotError("Failed to capture screenshot", url, { + "error": str(e), + "elapsed_seconds": time.time() - start_time + }) + finally: + await browser.close() + + except Exception as e: + # Clean up temporary file if we created one and failed + if output_path and os.path.exists(output_path): + try: + os.unlink(output_path) + except OSError: + pass # Ignore cleanup errors + + if isinstance(e, (ScreenshotError, ValidationError, FileError)): + raise + + logger.error("Unexpected error during screenshot", extra={ + "context": { + "url": url, + "error": str(e), + "elapsed_seconds": time.time() - start_time + } + }) + raise ScreenshotError("Screenshot capture failed", url, { + "error": str(e), + "elapsed_seconds": time.time() - start_time + }) def take_screenshot_sync(url: str, output_path: str = None, width: int = 1280, height: int = 720) -> str: """ Synchronous wrapper for take_screenshot. + + Args: + url: The URL to take a screenshot of + output_path: Path to save the screenshot. If None, saves to a temporary file. + width: Viewport width. Defaults to 1280. + height: Viewport height. Defaults to 720. + + Returns: + str: Path to the saved screenshot + + Raises: + ScreenshotError: If screenshot capture fails + ValidationError: If input parameters are invalid """ - return asyncio.run(take_screenshot(url, output_path, width, height)) + try: + return asyncio.run(take_screenshot(url, output_path, width, height)) + except Exception as e: + # Re-raise with appropriate type while preserving the traceback + if isinstance(e, (ScreenshotError, ValidationError, FileError)): + raise + raise ScreenshotError("Screenshot capture failed", url, { + "error": str(e) + }) -if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser(description='Take a screenshot of a webpage') +def main(): + parser = create_parser('Take a screenshot of a webpage') parser.add_argument('url', help='URL to take screenshot of') parser.add_argument('--output', '-o', help='Output path for screenshot') parser.add_argument('--width', '-w', type=int, default=1280, help='Viewport width') parser.add_argument('--height', '-H', type=int, default=720, help='Viewport height') args = parser.parse_args() - output_path = take_screenshot_sync(args.url, args.output, args.width, args.height) - print(f"Screenshot saved to: {output_path}") \ No newline at end of file + + # Configure logging + log_config = get_log_config(args) + logger = setup_logging(__name__, **log_config) + logger.debug("Debug logging enabled", extra={ + "context": { + "log_level": log_config["level"].value, + "log_format": log_config["format_type"].value + } + }) + + try: + path = take_screenshot_sync(args.url, args.output, args.width, args.height) + file_size = os.path.getsize(path) + elapsed = time.time() - start_time if 'start_time' in locals() else 0 + + result = { + "screenshot_path": str(path), + "file_size": format_file_size(file_size), + "elapsed_time": format_duration(elapsed), + "status": "success" + } + + metadata = { + "url": args.url, 
+ "dimensions": f"{args.width}x{args.height}" + } + + print(format_output(result, args.format, "Screenshot Captured", metadata)) + + except ValidationError as e: + logger.error("Invalid input", extra={"context": e.context}) + sys.exit(1) + except (ScreenshotError, FileError) as e: + logger.error(str(e), extra={"context": e.context}) + sys.exit(1) + except Exception as e: + logger.error("Unexpected error", extra={ + "context": { + "error": str(e) + } + }) + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tools/search_engine.py b/tools/search_engine.py index 120544e..22f109c 100755 --- a/tools/search_engine.py +++ b/tools/search_engine.py @@ -1,79 +1,487 @@ #!/usr/bin/env python3 -import argparse import sys import time +from typing import List, Dict, Any, Optional +from abc import ABC, abstractmethod +import requests +from bs4 import BeautifulSoup from duckduckgo_search import DDGS +from googlesearch import search as google_search +from .common.logging_config import setup_logging +from .common.errors import ToolError, ValidationError, APIError +from .common.formatting import format_output, format_duration +from .common.cli import create_parser, get_log_config -def search_with_retry(query, max_results=10, max_retries=3): +logger = setup_logging(__name__) + +class SearchError(APIError): + """Custom exception for search failures""" + def __init__(self, message: str, engine: str, query: str, context: Optional[Dict[str, Any]] = None): + context = context or {} + context["query"] = query + super().__init__(message, engine, context) + self.query = query + +class SearchEngine(ABC): + """Abstract base class for search engines""" + + @abstractmethod + def search(self, query: str, max_results: int = 10, max_retries: int = 3, fetch_snippets: bool = True) -> List[Dict[str, Any]]: + """ + Execute a search query. + + Args: + query: Search query + max_results: Maximum number of results to return + max_retries: Maximum number of retry attempts + fetch_snippets: Whether to fetch page content for snippets + + Returns: + List of dicts containing search results + + Raises: + SearchError: If search operation fails + """ + pass + +def fetch_page_snippet(url: str, query: str, max_retries: int = 3) -> tuple[str, str]: """ - Search using DuckDuckGo and return results with URLs and text snippets. + Fetch a webpage and extract its title and a snippet of content. 
Args: - query (str): Search query - max_results (int): Maximum number of results to return - max_retries (int): Maximum number of retry attempts + url: URL to fetch + query: The search query that led to this URL + max_retries: Maximum number of retry attempts + + Returns: + tuple: (title, snippet) + + Raises: + SearchError: If page cannot be fetched or parsed """ + start_time = time.time() + headers = { + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' + } + for attempt in range(max_retries): try: - print(f"DEBUG: Searching for query: {query} (attempt {attempt + 1}/{max_retries})", - file=sys.stderr) + logger.debug(f"Fetching snippet from {url} (attempt {attempt + 1}/{max_retries})") + response = requests.get(url, headers=headers, timeout=10) + response.raise_for_status() - with DDGS() as ddgs: - results = list(ddgs.text(query, max_results=max_results)) - - if not results: - print("DEBUG: No results found", file=sys.stderr) - return [] + soup = BeautifulSoup(response.text, 'html.parser') - print(f"DEBUG: Found {len(results)} results", file=sys.stderr) - return results - + # Get title + title = soup.title.string if soup.title else url + title = title.strip() + + # Get snippet from meta description or first paragraph + snippet = "" + meta_desc = soup.find('meta', attrs={'name': 'description'}) + if meta_desc and meta_desc.get('content'): + snippet = meta_desc['content'].strip() + logger.debug(f"Using meta description for {url} ({len(snippet)} chars)") + else: + # Try to get first few paragraphs + paragraphs = soup.find_all('p') + text_chunks = [] + for p in paragraphs: + text = p.get_text().strip() + if text and len(text) > 50: # Skip short paragraphs + text_chunks.append(text) + if len(' '.join(text_chunks)) > 200: # Get enough text + break + if text_chunks: + snippet = ' '.join(text_chunks)[:300] + '...' 
# Limit length + logger.debug(f"Using paragraph content for {url} ({len(snippet)} chars)") + else: + logger.warning(f"No suitable content found for snippet in {url}") + + elapsed = time.time() - start_time + logger.info(f"Successfully fetched snippet from {url} in {elapsed:.2f}s") + return title, snippet + + except requests.exceptions.HTTPError as e: + logger.error(f"HTTP error fetching {url} (attempt {attempt + 1}/{max_retries}): {e.response.status_code}") + if attempt < max_retries - 1: + logger.debug("Waiting 1 second before retry...") + time.sleep(1) + else: + raise SearchError(f"HTTP error {e.response.status_code}", "fetch", query, { + "url": url, + "status_code": e.response.status_code + }) + except requests.exceptions.Timeout: + logger.error(f"Timeout fetching {url} (attempt {attempt + 1}/{max_retries})") + if attempt < max_retries - 1: + logger.debug("Waiting 1 second before retry...") + time.sleep(1) + else: + raise SearchError("Request timed out", "fetch", query, { + "url": url, + "timeout": timeout + }) + except requests.exceptions.RequestException as e: + logger.error(f"Network error fetching {url} (attempt {attempt + 1}/{max_retries}): {str(e)}") + if attempt < max_retries - 1: + logger.debug("Waiting 1 second before retry...") + time.sleep(1) + else: + raise SearchError("Network error", "fetch", query, { + "url": url, + "error": str(e) + }) except Exception as e: - print(f"ERROR: Attempt {attempt + 1}/{max_retries} failed: {str(e)}", file=sys.stderr) - if attempt < max_retries - 1: # If not the last attempt - print(f"DEBUG: Waiting 1 second before retry...", file=sys.stderr) - time.sleep(1) # Wait 1 second before retry + logger.error(f"Unexpected error fetching {url} (attempt {attempt + 1}/{max_retries}): {str(e)}") + if attempt < max_retries - 1: + logger.debug("Waiting 1 second before retry...") + time.sleep(1) else: - print(f"ERROR: All {max_retries} attempts failed", file=sys.stderr) - raise + raise SearchError("Failed to fetch page", "fetch", query, { + "url": url, + "error": str(e) + }) + +class DuckDuckGoEngine(SearchEngine): + """DuckDuckGo search implementation""" + + def search(self, query: str, max_results: int = 10, max_retries: int = 3, fetch_snippets: bool = True) -> List[Dict[str, Any]]: + start_time = time.time() + logger.info(f"Starting DuckDuckGo search for: {query}") + logger.debug(f"Search parameters: max_results={max_results}, max_retries={max_retries}") + + for attempt in range(max_retries): + try: + logger.debug(f"Search attempt {attempt + 1}/{max_retries}") + + with DDGS() as ddgs: + results = list(ddgs.text(query, max_results=max_results)) + + if not results: + logger.debug("No results found") + return [] + + result_count = len(results) + logger.debug(f"Found {result_count} results") + + # Normalize result format + formatted_results = [] + for i, r in enumerate(results, 1): + formatted_result = { + "url": r.get("href", ""), + "title": r.get("title", ""), + "snippet": r.get("body", "") + } + + # Validate and clean up result + if not formatted_result["url"]: + logger.warning(f"Result {i}/{result_count} missing URL, skipping") + continue + + if not formatted_result["title"]: + logger.debug(f"Result {i}/{result_count} missing title, using URL") + formatted_result["title"] = formatted_result["url"] + + if not formatted_result["snippet"]: + logger.debug(f"Result {i}/{result_count} missing snippet") + formatted_result["snippet"] = "" + + logger.debug(f"Result {i}/{result_count}: {formatted_result['url']}") + formatted_results.append(formatted_result) + + 
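Both engines normalize hits to the same `{url, title, snippet}` dictionary, so callers can switch engines without changing downstream code. A minimal sketch of that contract, assuming the module is importable as `tools.search_engine` and that outbound network access is available:

```
from tools.search_engine import DuckDuckGoEngine

engine = DuckDuckGoEngine()
for result in engine.search("python asyncio tutorial", max_results=3):
    # Every result carries the same keys regardless of the engine used.
    print(result["url"])
    print(result["title"])
    print(result["snippet"][:80])
```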
elapsed = time.time() - start_time + logger.info(f"Search completed in {elapsed:.2f}s with {len(formatted_results)} valid results") + return formatted_results + + except Exception as e: + logger.error(f"Attempt {attempt + 1}/{max_retries} failed: {str(e)}") + if attempt == max_retries - 1: + raise SearchError("DuckDuckGo search failed", "duckduckgo", query, { + "error": str(e), + "attempts": max_retries + }) + time.sleep(1) + continue + +class GoogleEngine(SearchEngine): + """Google search implementation""" + + def search(self, query: str, max_results: int = 10, max_retries: int = 3, fetch_snippets: bool = True) -> List[Dict[str, Any]]: + start_time = time.time() + logger.info(f"Starting Google search for: {query}") + logger.debug(f"Search parameters: max_results={max_results}, max_retries={max_retries}, fetch_snippets={fetch_snippets}") + + for attempt in range(max_retries): + try: + logger.debug(f"Search attempt {attempt + 1}/{max_retries}") + + results = [] + urls = list(google_search(query, num_results=max_results)) + + if not urls: + logger.info("No results found") + return [] + + logger.debug(f"Found {len(urls)} URLs") + + for i, url in enumerate(urls, 1): + try: + if fetch_snippets: + logger.debug(f"Fetching content for result {i}/{len(urls)}: {url}") + title, snippet = fetch_page_snippet(url, query, max_retries) + else: + logger.debug(f"Skipping content fetch for result {i}/{len(urls)}: {url}") + title, snippet = url, "" + + result = { + "url": url, + "title": title, + "snippet": snippet + } + logger.debug(f"Result {i}/{len(urls)}: {url}") + results.append(result) + + except SearchError as e: + logger.warning(f"Failed to fetch snippet for {url}: {str(e)}") + # Include the result anyway, just without a snippet + results.append({ + "url": url, + "title": url, + "snippet": "" + }) + + if not results: + logger.info("No valid results found") + return [] + + elapsed = time.time() - start_time + logger.info(f"Search completed in {elapsed:.2f}s with {len(results)} results") + return results + + except Exception as e: + logger.error(f"Search attempt {attempt + 1}/{max_retries} failed: {str(e)}") + if attempt < max_retries - 1: + logger.debug("Waiting 1 second before retry...") + time.sleep(1) + else: + raise SearchError("Google search failed", "google", query, { + "error": str(e), + "attempts": max_retries + }) -def format_results(results): - """Format and print search results.""" - for i, r in enumerate(results, 1): - print(f"\n=== Result {i} ===") - print(f"URL: {r.get('href', 'N/A')}") - print(f"Title: {r.get('title', 'N/A')}") - print(f"Snippet: {r.get('body', 'N/A')}") +def get_search_engine(engine: str = "duckduckgo") -> SearchEngine: + """ + Get a search engine instance by name. + + Args: + engine: Name of the search engine ('duckduckgo' or 'google') + + Returns: + SearchEngine instance + + Raises: + ValidationError: If engine name is invalid + """ + engines = { + "duckduckgo": DuckDuckGoEngine, + "google": GoogleEngine + } + + engine_class = engines.get(engine.lower()) + if not engine_class: + raise ValidationError("Invalid search engine", { + "engine": engine, + "valid_options": list(engines.keys()) + }) + + return engine_class() -def search(query, max_results=10, max_retries=3): +def validate_query(query: str) -> str: """ - Main search function that handles search with retry mechanism. + Validate and normalize search query. 
Args: - query (str): Search query - max_results (int): Maximum number of results to return - max_retries (int): Maximum number of retry attempts + query: Search query string + + Returns: + str: Normalized query + + Raises: + ValidationError: If query is invalid """ - try: - results = search_with_retry(query, max_results, max_retries) - if results: - format_results(results) + if not query or not query.strip(): + raise ValidationError("Search query cannot be empty") + return query.strip() + +def validate_max_results(max_results: int) -> int: + """ + Validate and normalize max_results parameter. + + Args: + max_results: Maximum number of results to return + + Returns: + int: Normalized max_results value + + Raises: + ValidationError: If max_results is invalid + """ + if max_results < 1: + raise ValidationError("max_results must be a positive integer", { + "max_results": max_results + }) + if max_results > 100: # Reasonable upper limit + raise ValidationError("max_results cannot exceed 100", { + "max_results": max_results, + "max_allowed": 100 + }) + return max_results + +def search(query: str, max_results: int = 10, max_retries: int = 3, engine: str = "duckduckgo", fetch_snippets: bool = True) -> List[Dict[str, Any]]: + """ + Search using the specified engine and return results with URLs and text snippets. + + Args: + query: Search query + max_results: Maximum number of results to return (1-100) + max_retries: Maximum number of retry attempts + engine: Search engine to use ('duckduckgo' or 'google') + fetch_snippets: Whether to fetch page content for snippets (Google only) + + Returns: + List of dicts containing search results with keys: + - url: Result URL + - title: Result title + - snippet: Text snippet + Raises: + ValidationError: If input parameters are invalid + SearchError: If search operation fails + """ + # Validate inputs + try: + query = validate_query(query) + max_results = validate_max_results(max_results) + except ValidationError as e: + logger.error("Invalid input parameters", extra={ + "context": e.context + }) + raise + + logger.info(f"Searching for: {query}", extra={ + "context": { + "engine": engine, + "max_results": max_results, + "max_retries": max_retries, + "fetch_snippets": fetch_snippets + } + }) + + start_time = time.time() + try: + # Get search engine instance + search_engine = get_search_engine(engine) + + # Perform search + results = search_engine.search(query, max_results, max_retries, fetch_snippets) + + # Log results + elapsed = time.time() - start_time + logger.info(f"Search completed in {elapsed:.2f}s", extra={ + "context": { + "engine": engine, + "result_count": len(results), + "elapsed_seconds": elapsed + } + }) + + return results + + except ValidationError as e: + logger.error("Invalid search engine", extra={ + "context": e.context + }) + raise + except SearchError as e: + logger.error("Search failed", extra={ + "context": e.context + }) + raise except Exception as e: - print(f"ERROR: Search failed: {str(e)}", file=sys.stderr) - sys.exit(1) + logger.error("Unexpected error during search", extra={ + "context": { + "engine": engine, + "query": query, + "error": str(e) + } + }) + raise SearchError("Search failed", engine, query, { + "error": str(e) + }) def main(): - parser = argparse.ArgumentParser(description="Search using DuckDuckGo API") + parser = create_parser("Search using various search engines") parser.add_argument("query", help="Search query") + parser.add_argument("--engine", choices=["duckduckgo", "google"], default="duckduckgo", + help="Search 
engine to use") parser.add_argument("--max-results", type=int, default=10, - help="Maximum number of results (default: 10)") + help="Maximum number of results") parser.add_argument("--max-retries", type=int, default=3, - help="Maximum number of retry attempts (default: 3)") + help="Maximum number of retry attempts") + parser.add_argument("--no-fetch-snippets", action="store_true", + help="Don't fetch page content for snippets (Google only)") args = parser.parse_args() - search(args.query, args.max_results, args.max_retries) + + # Configure logging + log_config = get_log_config(args) + logger = setup_logging(__name__, **log_config) + logger.debug("Debug logging enabled", extra={ + "context": { + "log_level": log_config["level"].value, + "log_format": log_config["format_type"].value + } + }) + + try: + start_time = time.time() + results = search( + args.query, + args.max_results, + args.max_retries, + args.engine, + not args.no_fetch_snippets + ) + elapsed = time.time() - start_time + + metadata = { + "engine": args.engine, + "query": args.query, + "elapsed_time": format_duration(elapsed), + "result_count": len(results) + } + + print(format_output(results, args.format, "Search Results", metadata)) + + # Exit with error if no results found + if not results: + sys.exit(1) + + except ValidationError as e: + logger.error("Invalid input", extra={"context": e.context}) + sys.exit(1) + except SearchError as e: + logger.error(str(e), extra={"context": e.context}) + sys.exit(1) + except Exception as e: + logger.error("Unexpected error", extra={ + "context": { + "error": str(e) + } + }) + sys.exit(1) if __name__ == "__main__": main() diff --git a/tools/token_tracker.py b/tools/token_tracker.py index 5213d73..7593871 100644 --- a/tools/token_tracker.py +++ b/tools/token_tracker.py @@ -4,23 +4,55 @@ import time import json import argparse -from dataclasses import dataclass -from typing import Optional, Dict, List +from dataclasses import dataclass, asdict +from typing import Optional, Dict, List, Any from pathlib import Path import uuid import sys +import tempfile +import shutil from tabulate import tabulate from datetime import datetime +from .common.logging_config import setup_logging, LogLevel, LogFormat +from .common.errors import ToolError, ValidationError, FileError +from .common.formatting import format_output, format_cost, format_duration + +logger = setup_logging(__name__) @dataclass class TokenUsage: + """Token usage information for an API request""" prompt_tokens: int completion_tokens: int total_tokens: int reasoning_tokens: Optional[int] = None + def __post_init__(self): + """Validate token counts""" + if self.prompt_tokens < 0: + raise ValidationError("prompt_tokens must be non-negative", { + "field": "prompt_tokens", + "value": self.prompt_tokens + }) + if self.completion_tokens < 0: + raise ValidationError("completion_tokens must be non-negative", { + "field": "completion_tokens", + "value": self.completion_tokens + }) + if self.total_tokens < 0: + raise ValidationError("total_tokens must be non-negative", { + "field": "total_tokens", + "value": self.total_tokens + }) + if self.reasoning_tokens is not None and self.reasoning_tokens < 0: + raise ValidationError("reasoning_tokens must be non-negative", { + "field": "reasoning_tokens", + "value": self.reasoning_tokens + }) + @dataclass class APIResponse: + """API response information""" content: str token_usage: TokenUsage cost: float @@ -28,42 +60,106 @@ class APIResponse: provider: str = "openai" model: str = "unknown" + def 
__post_init__(self): + """Validate response data""" + if not self.content: + raise ValidationError("content cannot be empty") + if self.cost < 0: + raise ValidationError("cost must be non-negative", { + "cost": self.cost + }) + if self.thinking_time < 0: + raise ValidationError("thinking_time must be non-negative", { + "thinking_time": self.thinking_time + }) + if not self.provider: + raise ValidationError("provider cannot be empty") + if not self.model: + raise ValidationError("model cannot be empty") + class TokenTracker: + """Track token usage and costs for API requests""" + def __init__(self, session_id: Optional[str] = None, logs_dir: Optional[Path] = None): - # If no session_id provided, use today's date - self.session_id = session_id or datetime.now().strftime("%Y-%m-%d") - self.session_start = time.time() - self.requests: List[Dict] = [] - - # Create logs directory if it doesn't exist - self._logs_dir = logs_dir or Path("token_logs") - self._logs_dir.mkdir(exist_ok=True) - - # Initialize session file - self._session_file = self._logs_dir / f"session_{self.session_id}.json" - - # Load existing session data if file exists - if self._session_file.exists(): - try: - with open(self._session_file, 'r') as f: - data = json.load(f) - self.session_start = data.get('start_time', self.session_start) - self.requests = data.get('requests', []) - except Exception as e: - print(f"Error loading existing session file: {e}", file=sys.stderr) + """ + Initialize token tracker. - self._save_session() + Args: + session_id: Optional session identifier + logs_dir: Optional directory for log files + """ + self._session_id = session_id or str(int(time.time())) + self._logs_dir = logs_dir or Path.home() / '.cursorrules' / 'logs' + self._requests: List[Dict[str, Any]] = [] + self._session_start = time.time() + + logger.debug("Initializing token tracker", extra={ + "context": { + "session_id": self._session_id, + "logs_dir": str(self._logs_dir) + } + }) + + try: + # Create logs directory if it doesn't exist + self._logs_dir.mkdir(parents=True, exist_ok=True) + + # Initialize session file + self._session_file = self._logs_dir / f"session_{self._session_id}.json" + + # Only load existing session if it matches our session ID + if self._session_file.exists() and session_id: + logger.debug("Loading existing session", extra={ + "context": { + "session_file": str(self._session_file) + } + }) + session_data = load_session(self._session_file) + if session_data and session_data.get('session_id') == self._session_id: + self._requests = session_data.get('requests', []) + self._session_start = session_data.get('start_time', self._session_start) + logger.info("Loaded existing session", extra={ + "context": { + "request_count": len(self._requests) + } + }) + except Exception as e: + logger.error("Failed to initialize token tracker", extra={ + "context": { + "session_id": self._session_id, + "logs_dir": str(self._logs_dir), + "error": str(e) + } + }) + raise def _save_session(self): """Save current session data to file""" - session_data = { - "session_id": self.session_id, - "start_time": self.session_start, - "requests": self.requests, - "summary": self.get_session_summary() - } - with open(self._session_file, "w") as f: - json.dump(session_data, f, indent=2) + try: + session_data = { + "session_id": self._session_id, + "start_time": self._session_start, + "requests": self._requests, + "summary": self.get_session_summary() + } + + with open(self._session_file, 'w') as f: + json.dump(session_data, f, indent=2) + + 
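
The tracker dataclasses validate themselves in `__post_init__`, so a bad token count fails at construction time rather than when the session is saved. A minimal sketch of that style, assuming a plain `ValueError` in place of the repo's `ValidationError`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TokenUsage:
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    reasoning_tokens: Optional[int] = None

    def __post_init__(self):
        # Reject negative counts as soon as the object is built.
        for field_name in ("prompt_tokens", "completion_tokens", "total_tokens"):
            if getattr(self, field_name) < 0:
                raise ValueError(f"{field_name} must be non-negative")
        if self.reasoning_tokens is not None and self.reasoning_tokens < 0:
            raise ValueError("reasoning_tokens must be non-negative")


print(TokenUsage(prompt_tokens=120, completion_tokens=40, total_tokens=160))
```
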
logger.debug("Saved session data", extra={ + "context": { + "session_file": str(self._session_file), + "request_count": len(self._requests) + } + }) + except Exception as e: + logger.error("Failed to save session data", extra={ + "context": { + "session_file": str(self._session_file), + "error": str(e) + } + }) + raise @property def logs_dir(self) -> Path: @@ -72,10 +168,27 @@ def logs_dir(self) -> Path: @logs_dir.setter def logs_dir(self, path: Path): - """Set the logs directory path and update session file path""" - self._logs_dir = path - self._logs_dir.mkdir(exist_ok=True) - self.session_file = self._logs_dir / f"session_{self.session_id}.json" + """ + Set the logs directory path and update session file path. + + Args: + path: New logs directory path + + Raises: + ValueError: If path is invalid + OSError: If directory cannot be created + """ + if not path: + raise ValueError("logs_dir path cannot be empty") + + logger.info(f"Changing logs directory to {path}") + try: + self._logs_dir = path + self._logs_dir.mkdir(exist_ok=True) + self._session_file = self._logs_dir / f"session_{self._session_id}.json" + except Exception as e: + logger.error(f"Failed to set logs directory to {path}: {e}") + raise @property def session_file(self) -> Path: @@ -84,67 +197,131 @@ def session_file(self) -> Path: @session_file.setter def session_file(self, path: Path): - """Set the session file path and load data if it exists""" + """ + Set the session file path and load data if it exists. + + Args: + path: New session file path + + Raises: + ValueError: If path is invalid + OSError: If file operations fail + """ + if not path: + raise ValueError("session_file path cannot be empty") + + logger.info(f"Changing session file to {path}") old_file = self._session_file self._session_file = path - # If we have data and the new file doesn't exist, save our data - if old_file.exists() and not path.exists() and self.requests: - self._save_session() - # If the new file exists, load its data - elif path.exists(): - try: + try: + # If we have data and the new file doesn't exist, save our data + if old_file.exists() and not path.exists() and self._requests: + logger.debug("Saving existing data to new session file") + self._save_session() + # If the new file exists, load its data + elif path.exists(): + logger.debug("Loading data from existing session file") with open(path, 'r') as f: data = json.load(f) - self.session_start = data.get('start_time', self.session_start) - self.requests = data.get('requests', []) - except Exception as e: - print(f"Error loading existing session file: {e}", file=sys.stderr) + self._requests = data.get('requests', []) + logger.info(f"Loaded {len(self._requests)} requests from {path}") + except Exception as e: + logger.error(f"Failed to handle session file change to {path}: {e}") + raise @staticmethod def calculate_openai_cost(prompt_tokens: int, completion_tokens: int, model: str) -> float: - """Calculate OpenAI API cost based on model and token usage""" - # Only support o1, gpt-4o, and deepseek-chat models - if model == "o1": - # o1 pricing per 1M tokens - INPUT_PRICE_PER_M = 15.0 - OUTPUT_PRICE_PER_M = 60.0 - elif model == "gpt-4o": - # gpt-4o pricing per 1M tokens - INPUT_PRICE_PER_M = 10.0 - OUTPUT_PRICE_PER_M = 30.0 - elif model == "deepseek-chat": - # DeepSeek pricing per 1M tokens - INPUT_PRICE_PER_M = 0.2 # $0.20 per million input tokens - OUTPUT_PRICE_PER_M = 0.2 # $0.20 per million output tokens - else: - raise ValueError(f"Unsupported OpenAI model for cost calculation: {model}. 
Only o1, gpt-4o, and deepseek-chat are supported.") - - input_cost = (prompt_tokens / 1_000_000) * INPUT_PRICE_PER_M - output_cost = (completion_tokens / 1_000_000) * OUTPUT_PRICE_PER_M - return input_cost + output_cost + """ + Calculate cost for OpenAI API usage. + + Args: + prompt_tokens: Number of prompt tokens + completion_tokens: Number of completion tokens + model: Model name + + Returns: + float: Cost in USD + """ + # Cost per 1K tokens (as of March 2024) + costs = { + "gpt-4": {"prompt": 0.03, "completion": 0.06}, + "gpt-4-turbo": {"prompt": 0.01, "completion": 0.03}, + "gpt-4o": {"prompt": 0.01, "completion": 0.03}, + "gpt-4o-ms": {"prompt": 0.01, "completion": 0.03}, + "gpt-3.5-turbo": {"prompt": 0.0005, "completion": 0.0015}, + "o1": {"prompt": 0.01, "completion": 0.03} + } + + model_costs = costs.get(model, costs["gpt-3.5-turbo"]) # Default to gpt-3.5-turbo costs + cost = (prompt_tokens * model_costs["prompt"] + completion_tokens * model_costs["completion"]) / 1000 + + logger.debug("Calculated OpenAI cost", extra={ + "context": { + "model": model, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "cost": cost + } + }) + + return cost @staticmethod def calculate_claude_cost(prompt_tokens: int, completion_tokens: int, model: str) -> float: - """Calculate Claude API cost based on model and token usage""" - # Claude-3 Sonnet pricing per 1M tokens - # Source: https://www.anthropic.com/claude/sonnet - if model in ["claude-3-5-sonnet-20241022", "claude-3-sonnet-20240229"]: - INPUT_PRICE_PER_M = 3.0 # $3 per million input tokens - OUTPUT_PRICE_PER_M = 15.0 # $15 per million output tokens - else: - raise ValueError(f"Unsupported Claude model for cost calculation: {model}. Only claude-3-5-sonnet-20241022 and claude-3-sonnet-20240229 are supported.") - - input_cost = (prompt_tokens / 1_000_000) * INPUT_PRICE_PER_M - output_cost = (completion_tokens / 1_000_000) * OUTPUT_PRICE_PER_M - return input_cost + output_cost + """ + Calculate cost for Claude API usage. + + Args: + prompt_tokens: Number of prompt tokens + completion_tokens: Number of completion tokens + model: Model name + + Returns: + float: Cost in USD + """ + # Cost per 1M tokens (as of March 2024) + costs = { + "claude-3-opus-20240229": {"prompt": 15.0, "completion": 75.0}, + "claude-3-sonnet-20240229": {"prompt": 3.0, "completion": 15.0}, + "claude-3-haiku-20240307": {"prompt": 0.25, "completion": 1.25}, + "claude-3-5-sonnet-20241022": {"prompt": 3.0, "completion": 15.0} + } + + model_costs = costs.get(model, costs["claude-3-sonnet-20240229"]) # Default to sonnet costs + cost = (prompt_tokens * model_costs["prompt"] + completion_tokens * model_costs["completion"]) / 1_000_000 + + logger.debug("Calculated Claude cost", extra={ + "context": { + "model": model, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "cost": cost + } + }) + + return cost def track_request(self, response: APIResponse): - """Track a new API request""" + """ + Track a new API request. 
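
Both pricing helpers boil down to a rate-table lookup plus a linear cost formula scaled by the table's token unit. An illustration with made-up per-million rates (the real rates live in the diff's dictionaries and change over time):

```python
# Hypothetical per-1M-token rates, for illustration only.
RATES_PER_M = {
    "example-model": {"prompt": 3.0, "completion": 15.0},
}


def estimate_cost(prompt_tokens: int, completion_tokens: int, model: str) -> float:
    rates = RATES_PER_M.get(model, RATES_PER_M["example-model"])
    return (prompt_tokens * rates["prompt"]
            + completion_tokens * rates["completion"]) / 1_000_000


print(f"${estimate_cost(1200, 300, 'example-model'):.6f}")  # $0.008100
```
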
+ + Args: + response: API response information + + Raises: + ValueError: If response is invalid + """ + # Validate response + if not response: + raise ValueError("response cannot be None") + # Only track costs for OpenAI and Anthropic if response.provider.lower() not in ["openai", "anthropic"]: + logger.debug(f"Skipping cost tracking for unsupported provider: {response.provider}") return - + + logger.debug(f"Tracking request for {response.provider} model {response.model}") request_data = { "timestamp": time.time(), "provider": response.provider, @@ -158,20 +335,33 @@ def track_request(self, response: APIResponse): "cost": response.cost, "thinking_time": response.thinking_time } - self.requests.append(request_data) - self._save_session() + self._requests.append(request_data) + + try: + self._save_session() + logger.info(f"Request tracked successfully. Total requests: {len(self._requests)}") + except Exception as e: + logger.error(f"Failed to save session after tracking request: {e}") + raise def get_session_summary(self) -> Dict: - """Get summary of token usage and costs for the current session""" - total_prompt_tokens = sum(r["token_usage"]["prompt_tokens"] for r in self.requests) - total_completion_tokens = sum(r["token_usage"]["completion_tokens"] for r in self.requests) - total_tokens = sum(r["token_usage"]["total_tokens"] for r in self.requests) - total_cost = sum(r["cost"] for r in self.requests) - total_thinking_time = sum(r["thinking_time"] for r in self.requests) + """ + Get summary of token usage and costs for the current session. + + Returns: + Dict containing session statistics + """ + logger.debug("Generating session summary") + + total_prompt_tokens = sum(r["token_usage"]["prompt_tokens"] for r in self._requests) + total_completion_tokens = sum(r["token_usage"]["completion_tokens"] for r in self._requests) + total_tokens = sum(r["token_usage"]["total_tokens"] for r in self._requests) + total_cost = sum(r["cost"] for r in self._requests) + total_thinking_time = sum(r["thinking_time"] for r in self._requests) # Group by provider provider_stats = {} - for r in self.requests: + for r in self._requests: provider = r["provider"] if provider not in provider_stats: provider_stats[provider] = { @@ -183,164 +373,166 @@ def get_session_summary(self) -> Dict: provider_stats[provider]["total_tokens"] += r["token_usage"]["total_tokens"] provider_stats[provider]["total_cost"] += r["cost"] - return { - "total_requests": len(self.requests), + summary = { + "total_requests": len(self._requests), "total_prompt_tokens": total_prompt_tokens, "total_completion_tokens": total_completion_tokens, "total_tokens": total_tokens, "total_cost": total_cost, "total_thinking_time": total_thinking_time, "provider_stats": provider_stats, - "session_duration": time.time() - self.session_start + "session_duration": time.time() - self._session_start } + + logger.debug(f"Session summary: {len(self._requests)} requests, {total_tokens} tokens, ${total_cost:.6f}") + return summary # Global token tracker instance _token_tracker: Optional[TokenTracker] = None def get_token_tracker(session_id: Optional[str] = None, logs_dir: Optional[Path] = None) -> TokenTracker: - """Get or create a global token tracker instance""" + """ + Get or create a global token tracker instance. 
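
`get_token_tracker` keeps one module-level tracker and only rebuilds it when a different session id is requested. A bare-bones sketch of that lazy singleton behaviour; `Tracker` here is a stand-in, not the repo class:

```python
from typing import Optional


class Tracker:
    def __init__(self, session_id: str):
        self.session_id = session_id


_tracker: Optional[Tracker] = None


def get_tracker(session_id: Optional[str] = None) -> Tracker:
    global _tracker
    if _tracker is None or (session_id is not None and session_id != _tracker.session_id):
        _tracker = Tracker(session_id or "default")
    return _tracker


t1 = get_tracker("2024-01-01")
t2 = get_tracker()               # no id -> reuse the current tracker
assert t1 is t2
t3 = get_tracker("2024-01-02")   # different id -> a new tracker replaces it
assert t3 is not t1
```
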
+ + Args: + session_id: Optional session identifier + logs_dir: Optional directory for log files + + Returns: + TokenTracker: Global token tracker instance + + Raises: + ValueError: If input parameters are invalid + OSError: If log directory cannot be created or accessed + """ global _token_tracker current_date = datetime.now().strftime("%Y-%m-%d") + logger.debug(f"Getting token tracker (session_id={session_id}, logs_dir={logs_dir})") + # If no tracker exists, create one if _token_tracker is None: + logger.debug("Creating new token tracker") _token_tracker = TokenTracker(session_id or current_date, logs_dir=logs_dir) return _token_tracker # If no session_id provided, reuse current tracker if session_id is None: + logger.debug("Reusing existing token tracker") if logs_dir is not None: _token_tracker.logs_dir = logs_dir return _token_tracker # If session_id matches current tracker, reuse it - if session_id == _token_tracker.session_id: + if session_id == _token_tracker._session_id: + logger.debug("Reusing existing token tracker with matching session_id") if logs_dir is not None: _token_tracker.logs_dir = logs_dir return _token_tracker # Otherwise, create a new tracker + logger.debug("Creating new token tracker with different session_id") _token_tracker = TokenTracker(session_id, logs_dir=logs_dir) return _token_tracker -# Viewing functionality (moved from view_usage.py) -def format_cost(cost: float) -> str: - """Format a cost value in dollars""" - return f"${cost:.6f}" - -def format_duration(seconds: float) -> str: - """Format duration in a human-readable format""" - if seconds < 60: - return f"{seconds:.2f}s" - minutes = seconds / 60 - if minutes < 60: - return f"{minutes:.2f}m" - hours = minutes / 60 - return f"{hours:.2f}h" - def load_session(session_file: Path) -> Optional[Dict]: - """Load a session file and return its contents""" + """ + Load session data from a file. 
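
`load_session` treats a missing file as a normal `None` result and only raises when an existing file cannot be read. A compact standard-library-only sketch of that contract, with `RuntimeError` in place of the repo's `FileError`:

```python
import json
from pathlib import Path
from typing import Any, Dict, Optional


def load_session(session_file: Path) -> Optional[Dict[str, Any]]:
    # A missing file is a normal condition, not an error.
    if not session_file.exists():
        return None
    try:
        with open(session_file, "r") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError) as exc:
        # An existing but unreadable/corrupt file is surfaced to the caller.
        raise RuntimeError(f"failed to load {session_file}: {exc}") from exc
```
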
+ + Args: + session_file: Path to session file + + Returns: + Optional[Dict]: Session data or None if file doesn't exist + + Raises: + FileError: If file exists but cannot be read + """ + if not session_file.exists(): + return None + try: with open(session_file, 'r') as f: - return json.load(f) + data = json.load(f) + logger.debug("Loaded session file", extra={ + "context": { + "path": str(session_file), + "session_id": data.get("session_id") + } + }) + return data except Exception as e: - print(f"Error loading session file {session_file}: {e}", file=sys.stderr) - return None + logger.error("Failed to load session file", extra={ + "context": { + "path": str(session_file), + "error": str(e) + } + }) + raise FileError("Failed to load session file", str(session_file), { + "error": str(e) + }) -def display_session_summary(session_data: Dict, show_requests: bool = False): - """Display a summary of the session""" - summary = session_data["summary"] - - # Print session overview - print("\nSession Overview") - print("===============") - print(f"Session ID: {session_data['session_id']}") - print(f"Duration: {format_duration(summary['session_duration'])}") - print(f"Total Requests: {summary['total_requests']}") - print(f"Total Cost: {format_cost(summary['total_cost'])}") +def main(): + parser = argparse.ArgumentParser( + description='Track and analyze token usage across LLM requests', + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument('--session', help='Session ID to analyze') + parser.add_argument('--logs-dir', type=Path, help='Directory for token tracking logs') - # Print token usage - print("\nToken Usage") - print("===========") - print(f"Prompt Tokens: {summary['total_prompt_tokens']:,}") - print(f"Completion Tokens: {summary['total_completion_tokens']:,}") - print(f"Total Tokens: {summary['total_tokens']:,}") + # Add output format options + parser.add_argument('--format', + choices=['text', 'json', 'markdown'], + default='text', + help='Output format') - # Print provider stats - print("\nProvider Statistics") - print("==================") - provider_data = [] - for provider, stats in summary["provider_stats"].items(): - provider_data.append([ - provider, - stats["requests"], - f"{stats['total_tokens']:,}", - format_cost(stats["total_cost"]) - ]) - print(tabulate( - provider_data, - headers=["Provider", "Requests", "Tokens", "Cost"], - tablefmt="simple" - )) + # Add mutually exclusive logging options + log_group = parser.add_mutually_exclusive_group() + log_group.add_argument('--log-level', + choices=['debug', 'info', 'warning', 'error', 'quiet'], + default='info', + help='Set the logging level') + log_group.add_argument('--debug', + action='store_true', + help='Enable debug logging (equivalent to --log-level debug)') + log_group.add_argument('--quiet', + action='store_true', + help='Minimize output (equivalent to --log-level quiet)') - # Print individual requests if requested - if show_requests: - print("\nIndividual Requests") - print("==================") - request_data = [] - for req in session_data["requests"]: - request_data.append([ - req["provider"], - req["model"], - f"{req['token_usage']['total_tokens']:,}", - format_cost(req["cost"]), - f"{req['thinking_time']:.2f}s" - ]) - print(tabulate( - request_data, - headers=["Provider", "Model", "Tokens", "Cost", "Time"], - tablefmt="simple" - )) - -def list_sessions(logs_dir: Path): - """List all available session files""" - session_files = sorted(logs_dir.glob("session_*.json")) - if not session_files: - 
print("No session files found.") - return - - for session_file in session_files: - session_data = load_session(session_file) - if session_data: - summary = session_data["summary"] - print(f"\nSession: {session_data['session_id']}") - print(f"Duration: {format_duration(summary['session_duration'])}") - print(f"Requests: {summary['total_requests']}") - print(f"Total Cost: {format_cost(summary['total_cost'])}") - print(f"Total Tokens: {summary['total_tokens']:,}") - -def main(): - parser = argparse.ArgumentParser(description='View LLM API usage statistics') - parser.add_argument('--session', type=str, help='Session ID to view details for') - parser.add_argument('--requests', action='store_true', help='Show individual requests') args = parser.parse_args() + + # Configure logging + log_config = get_log_config(args) + logger = setup_logging(__name__, **log_config) + logger.debug("Debug logging enabled", extra={ + "context": { + "log_level": log_config["level"].value, + "log_format": log_config["format_type"].value + } + }) - logs_dir = Path("token_logs") - if not logs_dir.exists(): - print("No logs directory found") - return - - if args.session: - session_file = logs_dir / f"session_{args.session}.json" - if not session_file.exists(): - print(f"Session file not found: {session_file}") - return + try: + tracker = get_token_tracker(args.session, args.logs_dir) + summary = tracker.get_session_summary() - session_data = load_session(session_file) - if session_data: - display_session_summary(session_data, args.requests) - else: - list_sessions(logs_dir) + metadata = { + "session_id": tracker._session_id, + "session_file": str(tracker.session_file), + "start_time": format_timestamp(tracker._session_start) + } + + print(format_output(summary, args.format, "Token Usage Summary", metadata)) + + except (ValidationError, FileError) as e: + logger.error(str(e), extra={"context": e.context}) + sys.exit(1) + except Exception as e: + logger.error("Unexpected error", extra={ + "context": { + "error": str(e) + } + }) + sys.exit(1) if __name__ == "__main__": main() \ No newline at end of file diff --git a/tools/web_scraper.py b/tools/web_scraper.py index fde66a0..ad45f0d 100755 --- a/tools/web_scraper.py +++ b/tools/web_scraper.py @@ -1,189 +1,467 @@ -#!/usr/bin/env /workspace/tmp_windsurf/venv/bin/python3 +#!/usr/bin/env python3 import asyncio -import argparse import sys import os -from typing import List, Optional +from typing import List, Optional, Dict, Any, Tuple from playwright.async_api import async_playwright import html5lib from multiprocessing import Pool import time from urllib.parse import urlparse -import logging import aiohttp +from .common.logging_config import setup_logging +from .common.errors import ToolError, ValidationError, APIError +from .common.formatting import format_output, format_duration +from .common.cli import create_parser, get_log_config -# Configure logging -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', - stream=sys.stderr -) -logger = logging.getLogger(__name__) +logger = setup_logging(__name__) -async def fetch_page(url: str, session: Optional[aiohttp.ClientSession] = None) -> Optional[str]: - """Asynchronously fetch a webpage's content.""" - if session is None: - async with aiohttp.ClientSession() as session: - try: - logger.info(f"Fetching {url}") - async with session.get(url) as response: +class FetchError(APIError): + """Custom exception for fetch failures""" + def __init__(self, message: str, url: str, status_code: 
Optional[int] = None, context: Optional[Dict[str, Any]] = None): + context = context or {} + context["url"] = url + if status_code is not None: + context["status_code"] = status_code + super().__init__(message, "fetch", context) + self.url = url + self.status_code = status_code + +def validate_url(url: str) -> bool: + """ + Validate if a string is a valid URL. + + Args: + url: URL to validate + + Returns: + bool: True if URL is valid + """ + try: + result = urlparse(url) + return all([result.scheme in ('http', 'https'), result.netloc]) + except: + return False + +def validate_max_concurrent(max_concurrent: int) -> int: + """ + Validate and normalize max_concurrent parameter. + + Args: + max_concurrent: Maximum number of concurrent requests + + Returns: + int: Normalized max_concurrent value + + Raises: + ValidationError: If max_concurrent is invalid + """ + if max_concurrent < 1: + raise ValidationError("max_concurrent must be a positive integer", { + "max_concurrent": max_concurrent + }) + if max_concurrent > 20: # Reasonable upper limit + raise ValidationError("max_concurrent cannot exceed 20", { + "max_concurrent": max_concurrent, + "max_allowed": 20 + }) + return max_concurrent + +async def fetch_page(url: str, session: Optional[aiohttp.ClientSession] = None, timeout: int = 30) -> str: + """ + Asynchronously fetch a webpage's content. + + Args: + url: URL to fetch + session: Optional aiohttp session to reuse + timeout: Request timeout in seconds + + Returns: + str: Page content + + Raises: + ValidationError: If URL is invalid + FetchError: If the page cannot be fetched + """ + if not validate_url(url): + logger.error("Invalid URL format", extra={ + "context": { + "url": url, + "error": "Invalid URL format" + } + }) + raise ValidationError("Invalid URL format", {"url": url}) + + logger.info("Fetching page", extra={ + "context": { + "url": url, + "timeout": timeout, + "reuse_session": session is not None + } + }) + + headers = { + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' + } + + start_time = time.time() + try: + if session is None: + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=headers, timeout=timeout) as response: if response.status == 200: content = await response.text() - logger.info(f"Successfully fetched {url}") + elapsed = time.time() - start_time + logger.info("Successfully fetched page", extra={ + "context": { + "url": url, + "elapsed_seconds": elapsed, + "content_length": len(content) + } + }) return content else: - logger.error(f"Error fetching {url}: HTTP {response.status}") - return None - except Exception as e: - logger.error(f"Error fetching {url}: {str(e)}") - return None - else: - try: - logger.info(f"Fetching {url}") - response = await session.get(url) - if response.status == 200: - content = await response.text() - logger.info(f"Successfully fetched {url}") - return content - else: - logger.error(f"Error fetching {url}: HTTP {response.status}") - return None - except Exception as e: - logger.error(f"Error fetching {url}: {str(e)}") - return None + logger.error("HTTP error", extra={ + "context": { + "url": url, + "status_code": response.status, + "elapsed_seconds": time.time() - start_time + } + }) + raise FetchError("HTTP error", url, response.status) + else: + async with session.get(url, headers=headers, timeout=timeout) as response: + if response.status == 200: + content = await response.text() + elapsed = time.time() - start_time + 
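
`fetch_page` funnels every failure mode of a single GET (timeout, client error, non-200 status) into one exception type. A trimmed-down standalone version of that flow, without the logging and with a plain `RuntimeError` instead of `FetchError`:

```python
import asyncio

import aiohttp


async def fetch(url: str, timeout_s: float = 30.0) -> str:
    timeout = aiohttp.ClientTimeout(total=timeout_s)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url) as response:
                if response.status != 200:
                    raise RuntimeError(f"HTTP {response.status} for {url}")
                return await response.text()
    except asyncio.TimeoutError as exc:
        raise RuntimeError(f"timed out after {timeout_s}s: {url}") from exc
    except aiohttp.ClientError as exc:
        raise RuntimeError(f"network error for {url}: {exc}") from exc


# html = asyncio.run(fetch("https://example.com"))
```
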
logger.info("Successfully fetched page", extra={ + "context": { + "url": url, + "elapsed_seconds": elapsed, + "content_length": len(content) + } + }) + return content + else: + logger.error("HTTP error", extra={ + "context": { + "url": url, + "status_code": response.status, + "elapsed_seconds": time.time() - start_time + } + }) + raise FetchError("HTTP error", url, response.status) + except asyncio.TimeoutError: + elapsed = time.time() - start_time + logger.error("Request timeout", extra={ + "context": { + "url": url, + "timeout": timeout, + "elapsed_seconds": elapsed + } + }) + raise FetchError("Request timed out", url, None, { + "timeout": timeout, + "elapsed_seconds": elapsed + }) + except aiohttp.ClientError as e: + logger.error("Network error", extra={ + "context": { + "url": url, + "error": str(e), + "elapsed_seconds": time.time() - start_time + } + }) + raise FetchError("Network error", url, None, { + "error": str(e), + "elapsed_seconds": time.time() - start_time + }) + except Exception as e: + logger.error("Unexpected error", extra={ + "context": { + "url": url, + "error": str(e), + "elapsed_seconds": time.time() - start_time + } + }) + raise FetchError("Failed to fetch page", url, None, { + "error": str(e), + "elapsed_seconds": time.time() - start_time + }) -def parse_html(html_content: Optional[str]) -> str: - """Parse HTML content and extract text with hyperlinks in markdown format.""" - if not html_content: - return "" +def parse_html(content: str, min_text_length: int = 50) -> str: + """ + Parse HTML content and extract text with hyperlinks in markdown format. - try: - document = html5lib.parse(html_content) - result = [] - seen_texts = set() # To avoid duplicates + Args: + content: HTML content to parse + min_text_length: Minimum length for text blocks to be included + + Returns: + str: Extracted text in markdown format + Raises: + ValidationError: If content is empty or invalid + """ + if not content or not content.strip(): + raise ValidationError("Empty content") + + try: + document = html5lib.parse(content) + parsed_text = [] + seen_texts = set() + def should_skip_element(elem) -> bool: - """Check if the element should be skipped.""" - # Skip script and style tags if elem.tag in ['{http://www.w3.org/1999/xhtml}script', - '{http://www.w3.org/1999/xhtml}style']: + '{http://www.w3.org/1999/xhtml}style', + '{http://www.w3.org/1999/xhtml}noscript', + '{http://www.w3.org/1999/xhtml}iframe']: return True - # Skip empty elements or elements with only whitespace if not any(text.strip() for text in elem.itertext()): return True return False - + def process_element(elem, depth=0): - """Process an element and its children recursively.""" if should_skip_element(elem): return - - # Handle text content + if hasattr(elem, 'text') and elem.text: text = elem.text.strip() if text and text not in seen_texts: - # Check if this is an anchor tag if elem.tag == '{http://www.w3.org/1999/xhtml}a': href = None for attr, value in elem.items(): if attr.endswith('href'): href = value break - if href and not href.startswith(('#', 'javascript:')): - # Format as markdown link + if href and not href.startswith(('#', 'javascript:', 'mailto:')): link_text = f"[{text}]({href})" - result.append(" " * depth + link_text) + parsed_text.append(" " * depth + link_text) seen_texts.add(text) else: - result.append(" " * depth + text) - seen_texts.add(text) - - # Process children + if len(text) >= min_text_length: + parsed_text.append(" " * depth + text) + seen_texts.add(text) + for child in elem: 
process_element(child, depth + 1) - - # Handle tail text + if hasattr(elem, 'tail') and elem.tail: tail = elem.tail.strip() - if tail and tail not in seen_texts: - result.append(" " * depth + tail) + if tail and tail not in seen_texts and len(tail) >= min_text_length: + parsed_text.append(" " * depth + tail) seen_texts.add(tail) - - # Start processing from the body tag + body = document.find('.//{http://www.w3.org/1999/xhtml}body') if body is not None: process_element(body) else: - # Fallback to processing the entire document process_element(document) - - # Filter out common unwanted patterns - filtered_result = [] - for line in result: - # Skip lines that are likely to be noise - if any(pattern in line.lower() for pattern in [ - 'var ', - 'function()', - '.js', - '.css', - 'google-analytics', - 'disqus', - '{', - '}' - ]): - continue - filtered_result.append(line) - - return '\n'.join(filtered_result) + + filtered_text = [ + line for line in parsed_text + if not any(pattern in line.lower() for pattern in [ + 'var ', 'function()', '.js', '.css', + 'google-analytics', 'disqus', '{', '}', + 'cookie', 'privacy policy', 'terms of service' + ]) + ] + + if not filtered_text: + logger.warning("No meaningful content extracted from HTML") + return "" + + return '\n'.join(filtered_text) + except Exception as e: - logger.error(f"Error parsing HTML: {str(e)}") - return "" + logger.error("Error parsing HTML", extra={ + "context": { + "error": str(e) + } + }) + raise ValidationError("Failed to parse HTML", { + "error": str(e) + }) + +async def process_urls(urls: List[str], max_concurrent: int = 5, session: Optional[aiohttp.ClientSession] = None, timeout: int = 30) -> Dict[str, Any]: + """ + Process multiple URLs concurrently. + + Args: + urls: List of URLs to process + max_concurrent: Maximum number of concurrent requests (1-20) + session: Optional aiohttp session to reuse + timeout: Request timeout in seconds + + Returns: + Dict containing: + - results: List of successfully parsed content + - errors: Dict mapping failed URLs to their error messages + + Raises: + ValidationError: If input parameters are invalid + """ + # Validate inputs + try: + max_concurrent = validate_max_concurrent(max_concurrent) + except ValidationError as e: + logger.error("Invalid max_concurrent", extra={ + "context": e.context + }) + raise + + # Filter out invalid URLs + valid_urls = [] + for url in urls: + if validate_url(url): + valid_urls.append(url) + else: + logger.warning("Skipping invalid URL", extra={ + "context": { + "url": url, + "error": "Invalid URL format" + } + }) + + if not valid_urls: + logger.error("No valid URLs provided", extra={ + "context": { + "total_urls": len(urls), + "valid_urls": 0 + } + }) + raise ValidationError("No valid URLs provided", { + "total_urls": len(urls) + }) + + logger.info("Processing URLs", extra={ + "context": { + "total_urls": len(urls), + "valid_urls": len(valid_urls), + "max_concurrent": max_concurrent, + "timeout": timeout + } + }) + + results = [] + errors = {} + + async def process_url(url: str, session: aiohttp.ClientSession) -> None: + try: + logger.debug("Starting URL processing", extra={ + "context": { + "url": url + } + }) + content = await fetch_page(url, session, timeout) + parsed = parse_html(content) + if parsed: + results.append({ + "url": url, + "content": parsed, + "timestamp": time.time() + }) + logger.debug("Successfully processed URL", extra={ + "context": { + "url": url, + "content_length": len(parsed) + } + }) + else: + logger.warning("No content extracted", 
extra={ + "context": { + "url": url + } + }) + errors[url] = "No meaningful content extracted" + except Exception as e: + logger.error("Failed to process URL", extra={ + "context": { + "url": url, + "error": str(e) + } + }) + errors[url] = str(e) -async def process_urls(urls: List[str], max_concurrent: int = 5, session: Optional[aiohttp.ClientSession] = None) -> List[str]: - """Process multiple URLs concurrently.""" if session is None: async with aiohttp.ClientSession() as session: - tasks = [fetch_page(url, session) for url in urls] - html_contents = await asyncio.gather(*tasks) + tasks = [process_url(url, session) for url in valid_urls] + # Process in batches to respect max_concurrent + for i in range(0, len(tasks), max_concurrent): + batch = tasks[i:i + max_concurrent] + await asyncio.gather(*batch, return_exceptions=True) else: - tasks = [fetch_page(url, session) for url in urls] - html_contents = await asyncio.gather(*tasks) - - # Parse HTML contents in parallel - with Pool() as pool: - results = pool.map(parse_html, html_contents) - - return results + tasks = [process_url(url, session) for url in valid_urls] + for i in range(0, len(tasks), max_concurrent): + batch = tasks[i:i + max_concurrent] + await asyncio.gather(*batch, return_exceptions=True) -def validate_url(url: str) -> bool: - """Validate if a string is a valid URL.""" - try: - result = urlparse(url) - return all([result.scheme, result.netloc]) - except: - return False + logger.info("Completed URL processing", extra={ + "context": { + "total_urls": len(valid_urls), + "successful": len(results), + "failed": len(errors) + } + }) + + return { + "results": results, + "errors": errors + } def main(): - """Main function to process URLs from command line.""" - parser = argparse.ArgumentParser(description='Fetch and process multiple URLs concurrently.') + parser = create_parser('Fetch and process multiple URLs concurrently') parser.add_argument('urls', nargs='+', help='URLs to process') - parser.add_argument('--max-concurrent', type=int, default=5, help='Maximum number of concurrent requests') - args = parser.parse_args() + parser.add_argument('--max-concurrent', type=int, default=5, + help='Maximum number of concurrent requests (1-20)') + parser.add_argument('--timeout', type=int, default=30, + help='Request timeout in seconds') - # Validate URLs - valid_urls = [url for url in args.urls if validate_url(url)] - if not valid_urls: - logger.error("No valid URLs provided") + args = parser.parse_args() + + # Configure logging + log_config = get_log_config(args) + logger = setup_logging(__name__, **log_config) + logger.debug("Debug logging enabled", extra={ + "context": { + "log_level": log_config["level"].value, + "log_format": log_config["format_type"].value + } + }) + + try: + start_time = time.time() + results = asyncio.run(process_urls( + args.urls, + args.max_concurrent, + timeout=args.timeout + )) + elapsed = time.time() - start_time + + metadata = { + "total_urls": len(args.urls), + "successful": len(results["results"]), + "failed": len(results["errors"]), + "elapsed_time": format_duration(elapsed) + } + + print(format_output(results, args.format, "Web Scraping Results", metadata)) + + # Exit with error if any URLs failed + if results["errors"]: + sys.exit(1) + + except ValidationError as e: + logger.error("Invalid input", extra={"context": e.context}) + sys.exit(1) + except Exception as e: + logger.error("Processing failed", extra={ + "context": { + "error": str(e) + } + }) sys.exit(1) - - # Process URLs - results = 
asyncio.run(process_urls(valid_urls, args.max_concurrent))
-
-    # Print results
-    for url, content in zip(valid_urls, results):
-        print(f"\n=== Content from {url} ===\n")
-        print(content)
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
\ No newline at end of file
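
`process_urls` limits concurrency by slicing the coroutine list into batches of `max_concurrent` and awaiting each batch with `asyncio.gather(..., return_exceptions=True)`. A self-contained sketch of that batching idea; `fake_fetch` stands in for the real page fetch:

```python
import asyncio
from typing import Awaitable, List


async def run_in_batches(coros: List[Awaitable], batch_size: int) -> list:
    results = []
    for i in range(0, len(coros), batch_size):
        batch = coros[i:i + batch_size]
        # return_exceptions=True keeps one failure from cancelling the rest of the batch.
        results.extend(await asyncio.gather(*batch, return_exceptions=True))
    return results


async def fake_fetch(url: str) -> str:
    await asyncio.sleep(0.01)
    return f"content of {url}"


urls = [f"https://example.com/{i}" for i in range(7)]
print(asyncio.run(run_in_batches([fake_fetch(u) for u in urls], batch_size=3)))
```

An `asyncio.Semaphore` would start a new request as soon as any slot frees up instead of waiting for the slowest item in each batch, but batching is simpler and is what the diff implements.
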