diff --git a/scrapegraph-py/CHANGELOG.md b/scrapegraph-py/CHANGELOG.md index 441ecfa..07f3688 100644 --- a/scrapegraph-py/CHANGELOG.md +++ b/scrapegraph-py/CHANGELOG.md @@ -1,3 +1,31 @@ +## [1.6.0-beta.1](https://github.com/ScrapeGraphAI/scrapegraph-sdk/compare/v1.5.0...v1.6.0-beta.1) (2024-12-05) + + +### Features + +* changed SyncClient to Client ([9e1e496](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/9e1e496059cd24810a96b818da1811830586f94b)) + + +### Bug Fixes + +* logger working properly now ([9712d4c](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/9712d4c39eea860f813e86a5e2ffc14db6d3a655)) +* updated env variable loading ([2643f11](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/2643f11c968f0daab26529d513f08c2817763b50)) + + +### CI + +* **release:** 1.4.3-beta.2 [skip ci] ([8ab6147](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/8ab61476b6763b936e2e7d423b04bb51983fb8ea)) +* **release:** 1.4.3-beta.3 [skip ci] ([1bc26c7](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/1bc26c738443f7f52492a7b2cbe7c9f335315797)) +* **release:** 1.5.0-beta.1 [skip ci] ([8900f7b](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/8900f7bf53239b6a73fb41196f5327d05763bae4)) + +## [1.5.0-beta.1](https://github.com/ScrapeGraphAI/scrapegraph-sdk/compare/v1.4.3-beta.3...v1.5.0-beta.1) (2024-12-05) + + +### Features + +* changed SyncClient to Client ([9e1e496](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/9e1e496059cd24810a96b818da1811830586f94b)) + + ## [1.5.0](https://github.com/ScrapeGraphAI/scrapegraph-sdk/compare/v1.4.3...v1.5.0) (2024-12-04) @@ -7,6 +35,19 @@ ## [1.4.3](https://github.com/ScrapeGraphAI/scrapegraph-sdk/compare/v1.4.2...v1.4.3) (2024-12-03) +## [1.4.3-beta.3](https://github.com/ScrapeGraphAI/scrapegraph-sdk/compare/v1.4.3-beta.2...v1.4.3-beta.3) (2024-12-05) + + +### Bug Fixes + +* updated env variable loading 
([2643f11](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/2643f11c968f0daab26529d513f08c2817763b50)) + +## [1.4.3-beta.2](https://github.com/ScrapeGraphAI/scrapegraph-sdk/compare/v1.4.3-beta.1...v1.4.3-beta.2) (2024-12-05) + + +### Bug Fixes + +* logger working properly now ([9712d4c](https://github.com/ScrapeGraphAI/scrapegraph-sdk/commit/9712d4c39eea860f813e86a5e2ffc14db6d3a655)) ### Bug Fixes diff --git a/scrapegraph-py/examples/async_smartscraper_example.py b/scrapegraph-py/examples/async_smartscraper_example.py index 3158ae8..ca17f01 100644 --- a/scrapegraph-py/examples/async_smartscraper_example.py +++ b/scrapegraph-py/examples/async_smartscraper_example.py @@ -1,9 +1,9 @@ import asyncio from scrapegraph_py import AsyncClient -from scrapegraph_py.logger import get_logger +from scrapegraph_py.logger import sgai_logger -get_logger(level="DEBUG") +sgai_logger.set_logging(level="INFO") async def main(): diff --git a/scrapegraph-py/examples/smartscraper_schema_async_example.py b/scrapegraph-py/examples/async_smartscraper_schema_example.py similarity index 99% rename from scrapegraph-py/examples/smartscraper_schema_async_example.py rename to scrapegraph-py/examples/async_smartscraper_schema_example.py index a1a3e23..d7cd4fa 100644 --- a/scrapegraph-py/examples/smartscraper_schema_async_example.py +++ b/scrapegraph-py/examples/async_smartscraper_schema_example.py @@ -1,5 +1,7 @@ import asyncio + from pydantic import BaseModel, Field + from scrapegraph_py import AsyncClient @@ -27,5 +29,6 @@ async def main(): await sgai_client.close() + if __name__ == "__main__": asyncio.run(main()) diff --git a/scrapegraph-py/examples/feedback_example.py b/scrapegraph-py/examples/feedback_example.py index 5d14613..dc20ae2 100644 --- a/scrapegraph-py/examples/feedback_example.py +++ b/scrapegraph-py/examples/feedback_example.py @@ -1,24 +1,28 @@ -from scrapegraph_py import SyncClient -from scrapegraph_py.logger import get_logger - -get_logger(level="DEBUG") - -# Initialize 
the client -sgai_client = SyncClient(api_key="your-api-key-here") - -# Example request_id (replace with an actual request_id from a previous request) -request_id = "your-request-id-here" - -# Submit feedback for a previous request -feedback_response = sgai_client.submit_feedback( - request_id=request_id, - rating=5, # Rating from 1-5 - feedback_text="The extraction was accurate and exactly what I needed!", -) -print(f"\nFeedback Response: {feedback_response}") - -# Get previous results using get_smartscraper -previous_result = sgai_client.get_smartscraper(request_id=request_id) -print(f"\nRetrieved Previous Result: {previous_result}") - -sgai_client.close() +from scrapegraph_py import Client +from scrapegraph_py.logger import sgai_logger + +sgai_logger.set_logging(level="INFO") + +# Initialize the client +sgai_client = Client(api_key="your-api-key-here") + +# Example request_id (replace with an actual request_id from a previous request) +request_id = "your-request-id-here" + +# Check remaining credits +credits = sgai_client.get_credits() +print(f"Credits Info: {credits}") + +# Submit feedback for a previous request +feedback_response = sgai_client.submit_feedback( + request_id=request_id, + rating=5, # Rating from 1-5 + feedback_text="The extraction was accurate and exactly what I needed!", +) +print(f"\nFeedback Response: {feedback_response}") + +# Get previous results using get_smartscraper +previous_result = sgai_client.get_smartscraper(request_id=request_id) +print(f"\nRetrieved Previous Result: {previous_result}") + +sgai_client.close() diff --git a/scrapegraph-py/examples/get_credits_example.py b/scrapegraph-py/examples/get_credits_example.py index c62bc06..6ef9e2f 100644 --- a/scrapegraph-py/examples/get_credits_example.py +++ b/scrapegraph-py/examples/get_credits_example.py @@ -1,10 +1,10 @@ -from scrapegraph_py import SyncClient -from scrapegraph_py.logger import get_logger +from scrapegraph_py import Client +from scrapegraph_py.logger import sgai_logger 
-get_logger(level="DEBUG") +sgai_logger.set_logging(level="DEBUG") # Initialize the client -sgai_client = SyncClient(api_key="your-api-key-here") +sgai_client = Client(api_key="your-api-key-here") # Check remaining credits credits = sgai_client.get_credits() diff --git a/scrapegraph-py/examples/smartscraper_example.py b/scrapegraph-py/examples/smartscraper_example.py index 5d44348..37e4542 100644 --- a/scrapegraph-py/examples/smartscraper_example.py +++ b/scrapegraph-py/examples/smartscraper_example.py @@ -1,10 +1,10 @@ -from scrapegraph_py import SyncClient -from scrapegraph_py.logger import get_logger +from scrapegraph_py import Client +from scrapegraph_py.logger import sgai_logger -get_logger(level="DEBUG") +sgai_logger.set_logging(level="INFO") -# Initialize the client -sgai_client = SyncClient(api_key="your-api-key-here") +# Initialize the client with explicit API key +sgai_client = Client(api_key="your-api-key-here") # SmartScraper request response = sgai_client.smartscraper( diff --git a/scrapegraph-py/examples/smartscraper_schema_example.py b/scrapegraph-py/examples/smartscraper_schema_example.py index fcc17c0..3553a22 100644 --- a/scrapegraph-py/examples/smartscraper_schema_example.py +++ b/scrapegraph-py/examples/smartscraper_schema_example.py @@ -1,6 +1,6 @@ from pydantic import BaseModel, Field -from scrapegraph_py import SyncClient +from scrapegraph_py import Client # Define a Pydantic model for the output schema @@ -11,7 +11,7 @@ class WebpageSchema(BaseModel): # Initialize the client -sgai_client = SyncClient(api_key="your-api-key-here") +sgai_client = Client(api_key="your-api-key-here") # SmartScraper request with output schema response = sgai_client.smartscraper( diff --git a/scrapegraph-py/scrapegraph_py/__init__.py b/scrapegraph-py/scrapegraph_py/__init__.py index 4a07e84..7ab2178 100644 --- a/scrapegraph-py/scrapegraph_py/__init__.py +++ b/scrapegraph-py/scrapegraph_py/__init__.py @@ -1,4 +1,4 @@ from .async_client import AsyncClient -from .client 
import SyncClient +from .client import Client -__all__ = ["SyncClient", "AsyncClient"] +__all__ = ["Client", "AsyncClient"] diff --git a/scrapegraph-py/scrapegraph_py/async_client.py b/scrapegraph-py/scrapegraph_py/async_client.py index 2246736..e0bd2a5 100644 --- a/scrapegraph-py/scrapegraph_py/async_client.py +++ b/scrapegraph-py/scrapegraph_py/async_client.py @@ -49,7 +49,7 @@ def from_env( def __init__( self, - api_key: str, + api_key: str = None, verify_ssl: bool = True, timeout: float = 120, max_retries: int = 3, @@ -58,13 +58,24 @@ def __init__( """Initialize AsyncClient with configurable parameters. Args: - api_key: API key for authentication + api_key: API key for authentication. If None, will try to load from environment verify_ssl: Whether to verify SSL certificates timeout: Request timeout in seconds max_retries: Maximum number of retry attempts retry_delay: Delay between retries in seconds """ logger.info("🔑 Initializing AsyncClient") + + # Try to get API key from environment if not provided + if api_key is None: + from os import getenv + + api_key = getenv("SGAI_API_KEY") + if not api_key: + raise ValueError( + "SGAI_API_KEY not provided and not found in environment" + ) + validate_api_key(api_key) logger.debug( f"🛠️ Configuration: verify_ssl={verify_ssl}, timeout={timeout}, max_retries={max_retries}" diff --git a/scrapegraph-py/scrapegraph_py/client.py b/scrapegraph-py/scrapegraph_py/client.py index 30ce15c..df4c133 100644 --- a/scrapegraph-py/scrapegraph_py/client.py +++ b/scrapegraph-py/scrapegraph_py/client.py @@ -1,4 +1,4 @@ -# Sync client implementation goes here +# Client implementation goes here from typing import Any, Optional import requests @@ -17,7 +17,7 @@ from scrapegraph_py.utils.helpers import handle_sync_response, validate_api_key -class SyncClient: +class Client: @classmethod def from_env( cls, @@ -26,7 +26,7 @@ def from_env( max_retries: int = 3, retry_delay: float = 1.0, ): - """Initialize SyncClient using API key from environment 
variable. + """Initialize Client using API key from environment variable. Args: verify_ssl: Whether to verify SSL certificates @@ -35,6 +35,7 @@ def from_env( retry_delay: Delay between retries in seconds """ from os import getenv + api_key = getenv("SGAI_API_KEY") if not api_key: raise ValueError("SGAI_API_KEY environment variable not set") @@ -48,22 +49,33 @@ def from_env( def __init__( self, - api_key: str, + api_key: str = None, verify_ssl: bool = True, timeout: float = 120, max_retries: int = 3, retry_delay: float = 1.0, ): - """Initialize SyncClient with configurable parameters. + """Initialize Client with configurable parameters. Args: - api_key: API key for authentication + api_key: API key for authentication. If None, will try to load from environment verify_ssl: Whether to verify SSL certificates timeout: Request timeout in seconds max_retries: Maximum number of retry attempts retry_delay: Delay between retries in seconds """ - logger.info("🔑 Initializing SyncClient") + logger.info("🔑 Initializing Client") + + # Try to get API key from environment if not provided + if api_key is None: + from os import getenv + + api_key = getenv("SGAI_API_KEY") + if not api_key: + raise ValueError( + "SGAI_API_KEY not provided and not found in environment" + ) + validate_api_key(api_key) logger.debug( f"🛠️ Configuration: verify_ssl={verify_ssl}, timeout={timeout}, max_retries={max_retries}" @@ -95,7 +107,7 @@ def __init__( if not verify_ssl: urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - logger.info("✅ SyncClient initialized successfully") + logger.info("✅ Client initialized successfully") def _make_request(self, method: str, url: str, **kwargs) -> Any: """Make HTTP request with error handling.""" @@ -199,7 +211,7 @@ def submit_feedback( def close(self): """Close the session to free up resources""" - logger.info("🔒 Closing SyncClient session") + logger.info("🔒 Closing Client session") self.session.close() logger.debug("✅ Session closed 
successfully") diff --git a/scrapegraph-py/scrapegraph_py/logger.py b/scrapegraph-py/scrapegraph_py/logger.py index a1a6ba2..bd2e4dc 100644 --- a/scrapegraph-py/scrapegraph_py/logger.py +++ b/scrapegraph-py/scrapegraph_py/logger.py @@ -22,53 +22,98 @@ def format(self, record: logging.LogRecord) -> str: return super().format(record) -def get_logger( - name: str = "scrapegraph", - level: str = "INFO", - log_file: Optional[str] = None, - log_format: Optional[str] = None, -) -> logging.Logger: - """ - Get a configured logger instance with emoji support. - - Args: - name: Name of the logger (default: 'scrapegraph') - level: Logging level (default: 'INFO') - log_file: Optional file path to write logs to - log_format: Optional custom log format string - - Returns: - logging.Logger: Configured logger instance - """ - logger = logging.getLogger(name) - - # Return existing logger if already configured - if logger.handlers: - return logger - - # Set log level - level = getattr(logging, level.upper(), logging.INFO) - logger.setLevel(level) - - # Default format if none provided - if not log_format: - log_format = "%(levelname)-6s %(asctime)-15s %(message)s" - - formatter = EmojiFormatter(log_format) - - # Console handler - console_handler = logging.StreamHandler() - console_handler.setFormatter(formatter) - logger.addHandler(console_handler) - - # File handler if log_file specified - if log_file: - file_handler = logging.FileHandler(log_file) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) - - return logger - - -# Default sgai logger instance -sgai_logger = get_logger() +class ScrapegraphLogger: + """Class to manage Scrapegraph logging configuration""" + + _instance = None + _initialized = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super(ScrapegraphLogger, cls).__new__(cls) + return cls._instance + + def __init__(self): + if not self._initialized: + self.logger = logging.getLogger("scrapegraph") + 
self.logger.setLevel(logging.INFO) + self.enabled = False + self._initialized = True + + def set_logging( + self, + level: Optional[str] = None, + log_file: Optional[str] = None, + log_format: Optional[str] = None, + ) -> None: + """ + Configure logging settings. If level is None, logging will be disabled. + + Args: + level: Logging level (e.g., 'DEBUG', 'INFO'). None to disable logging. + log_file: Optional file path to write logs to + log_format: Optional custom log format string + """ + # Clear existing handlers + self.logger.handlers.clear() + + if level is None: + # Disable logging + self.enabled = False + return + + # Enable logging with specified level + self.enabled = True + level = getattr(logging, level.upper(), logging.INFO) + self.logger.setLevel(level) + + # Default format if none provided + if not log_format: + log_format = "%(emoji)s %(asctime)-15s %(message)s" + + formatter = EmojiFormatter(log_format) + + # Console handler + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + self.logger.addHandler(console_handler) + + # File handler if log_file specified + if log_file: + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + self.logger.addHandler(file_handler) + + def disable(self) -> None: + """Disable all logging""" + self.logger.handlers.clear() + self.enabled = False + + def debug(self, message: str) -> None: + """Log debug message if logging is enabled""" + if self.enabled: + self.logger.debug(message) + + def info(self, message: str) -> None: + """Log info message if logging is enabled""" + if self.enabled: + self.logger.info(message) + + def warning(self, message: str) -> None: + """Log warning message if logging is enabled""" + if self.enabled: + self.logger.warning(message) + + def error(self, message: str) -> None: + """Log error message if logging is enabled""" + if self.enabled: + self.logger.error(message) + + def critical(self, message: str) -> None: + """Log critical 
message if logging is enabled""" + if self.enabled: + self.logger.critical(message) + + +# Default logger instance +sgai_logger = ScrapegraphLogger() diff --git a/scrapegraph-py/tests/test_async_client.py b/scrapegraph-py/tests/test_async_client.py index e69de29..8ebb43b 100644 --- a/scrapegraph-py/tests/test_async_client.py +++ b/scrapegraph-py/tests/test_async_client.py @@ -0,0 +1,103 @@ +from uuid import uuid4 + +import pytest +from aioresponses import aioresponses + +from scrapegraph_py.async_client import AsyncClient +from scrapegraph_py.exceptions import APIError +from tests.utils import generate_mock_api_key + + +@pytest.fixture +def mock_api_key(): + return generate_mock_api_key() + + +@pytest.fixture +def mock_uuid(): + return str(uuid4()) + + +@pytest.mark.asyncio +async def test_smartscraper(mock_api_key): + with aioresponses() as mocked: + mocked.post( + "https://api.scrapegraphai.com/v1/smartscraper", + payload={ + "request_id": str(uuid4()), + "status": "completed", + "result": {"description": "Example domain."}, + }, + ) + + async with AsyncClient(api_key=mock_api_key) as client: + response = await client.smartscraper( + website_url="https://example.com", user_prompt="Describe this page." 
+ ) + assert response["status"] == "completed" + assert "description" in response["result"] + + +@pytest.mark.asyncio +async def test_get_credits(mock_api_key): + with aioresponses() as mocked: + mocked.get( + "https://api.scrapegraphai.com/v1/credits", + payload={"remaining_credits": 100, "total_credits_used": 50}, + ) + + async with AsyncClient(api_key=mock_api_key) as client: + response = await client.get_credits() + assert response["remaining_credits"] == 100 + assert response["total_credits_used"] == 50 + + +@pytest.mark.asyncio +async def test_submit_feedback(mock_api_key): + with aioresponses() as mocked: + mocked.post( + "https://api.scrapegraphai.com/v1/feedback", payload={"status": "success"} + ) + + async with AsyncClient(api_key=mock_api_key) as client: + response = await client.submit_feedback( + request_id=str(uuid4()), rating=5, feedback_text="Great service!" + ) + assert response["status"] == "success" + + +@pytest.mark.asyncio +async def test_get_smartscraper(mock_api_key, mock_uuid): + with aioresponses() as mocked: + mocked.get( + f"https://api.scrapegraphai.com/v1/smartscraper/{mock_uuid}", + payload={ + "request_id": mock_uuid, + "status": "completed", + "result": {"data": "test"}, + }, + ) + + async with AsyncClient(api_key=mock_api_key) as client: + response = await client.get_smartscraper(mock_uuid) + assert response["status"] == "completed" + assert response["request_id"] == mock_uuid + + +@pytest.mark.asyncio +async def test_api_error(mock_api_key): + with aioresponses() as mocked: + mocked.post( + "https://api.scrapegraphai.com/v1/smartscraper", + status=400, + payload={"error": "Bad request"}, + exception=APIError("Bad request", status_code=400), + ) + + async with AsyncClient(api_key=mock_api_key) as client: + with pytest.raises(APIError) as exc_info: + await client.smartscraper( + website_url="https://example.com", user_prompt="Describe this page." 
+ ) + assert exc_info.value.status_code == 400 + assert "Bad request" in str(exc_info.value) diff --git a/scrapegraph-py/tests/test_client.py b/scrapegraph-py/tests/test_client.py index 0adea06..6163a6d 100644 --- a/scrapegraph-py/tests/test_client.py +++ b/scrapegraph-py/tests/test_client.py @@ -1,44 +1,99 @@ +from uuid import uuid4 + import pytest -from unittest.mock import patch -from scrapegraph_py import SyncClient - -def test_smartscraper(): - # Mock response data - mock_response = { - "request_id": "test-123", - "result": { - "heading": "Example Domain", - "description": "This is a sample description", - "summary": "A test webpage summary" - } - } - - # Create client instance with dummy API key - client = SyncClient(api_key="test-api-key") - - # Mock the API call - with patch.object(client, '_make_request') as mock_request: - # Configure mock to return our test data - mock_request.return_value = mock_response - - # Make the smartscraper request +import responses + +from scrapegraph_py.client import Client +from tests.utils import generate_mock_api_key + + +@pytest.fixture +def mock_api_key(): + return generate_mock_api_key() + + +@pytest.fixture +def mock_uuid(): + return str(uuid4()) + + +@responses.activate +def test_smartscraper(mock_api_key): + # Mock the API response + responses.add( + responses.POST, + "https://api.scrapegraphai.com/v1/smartscraper", + json={ + "request_id": str(uuid4()), + "status": "completed", + "result": {"description": "Example domain."}, + }, + ) + + with Client(api_key=mock_api_key) as client: response = client.smartscraper( - website_url="https://example.com", - user_prompt="Extract the main heading, description, and summary of the webpage" + website_url="https://example.com", user_prompt="Describe this page." 
) + assert response["status"] == "completed" + + +@responses.activate +def test_get_smartscraper(mock_api_key, mock_uuid): + responses.add( + responses.GET, + f"https://api.scrapegraphai.com/v1/smartscraper/{mock_uuid}", + json={ + "request_id": mock_uuid, + "status": "completed", + "result": {"data": "test"}, + }, + ) + + with Client(api_key=mock_api_key) as client: + response = client.get_smartscraper(mock_uuid) + assert response["status"] == "completed" + assert response["request_id"] == mock_uuid + + +@responses.activate +def test_get_credits(mock_api_key): + responses.add( + responses.GET, + "https://api.scrapegraphai.com/v1/credits", + json={"remaining_credits": 100, "total_credits_used": 50}, + ) + + with Client(api_key=mock_api_key) as client: + response = client.get_credits() + assert response["remaining_credits"] == 100 + assert response["total_credits_used"] == 50 + + +@responses.activate +def test_submit_feedback(mock_api_key): + responses.add( + responses.POST, + "https://api.scrapegraphai.com/v1/feedback", + json={"status": "success"}, + ) + + with Client(api_key=mock_api_key) as client: + response = client.submit_feedback( + request_id=str(uuid4()), rating=5, feedback_text="Great service!" 
+ ) + assert response["status"] == "success" + + +@responses.activate +def test_network_error(mock_api_key): + responses.add( + responses.POST, + "https://api.scrapegraphai.com/v1/smartscraper", + body=ConnectionError("Network error"), + ) - # Verify the request was made with correct parameters - mock_request.assert_called_once() - call_args = mock_request.call_args[0][0] - assert call_args['method'] == 'POST' - assert 'smartscraper' in call_args['url'] - assert call_args['json']['website_url'] == "https://example.com" - assert call_args['json']['user_prompt'] == "Extract the main heading, description, and summary of the webpage" - - # Verify response structure and content - assert isinstance(response, dict) - assert response['request_id'] == "test-123" - assert isinstance(response['result'], dict) - - # Clean up - client.close() + with Client(api_key=mock_api_key) as client: + with pytest.raises(ConnectionError): + client.smartscraper( + website_url="https://example.com", user_prompt="Describe this page." 
+ ) diff --git a/scrapegraph-py/tests/test_exceptions.py b/scrapegraph-py/tests/test_exceptions.py new file mode 100644 index 0000000..8d19815 --- /dev/null +++ b/scrapegraph-py/tests/test_exceptions.py @@ -0,0 +1,15 @@ +from scrapegraph_py.exceptions import APIError + + +def test_api_error(): + error = APIError("Test error", status_code=400) + assert str(error) == "[400] Test error" + assert error.status_code == 400 + assert error.message == "Test error" + + +def test_api_error_without_status(): + error = APIError("Test error") + assert str(error) == "[None] Test error" + assert error.status_code is None + assert error.message == "Test error" diff --git a/scrapegraph-py/tests/test_models.py b/scrapegraph-py/tests/test_models.py new file mode 100644 index 0000000..3e75169 --- /dev/null +++ b/scrapegraph-py/tests/test_models.py @@ -0,0 +1,82 @@ +import pytest +from pydantic import BaseModel, ValidationError + +from scrapegraph_py.models.feedback import FeedbackRequest +from scrapegraph_py.models.smartscraper import ( + GetSmartScraperRequest, + SmartScraperRequest, +) + + +def test_smartscraper_request_validation(): + + class ExampleSchema(BaseModel): + name: str + age: int + + # Valid input + request = SmartScraperRequest( + website_url="https://example.com", user_prompt="Describe this page." + ) + assert request.website_url == "https://example.com" + assert request.user_prompt == "Describe this page." 
+ + # Test with output_schema + request = SmartScraperRequest( + website_url="https://example.com", + user_prompt="Describe this page.", + output_schema=ExampleSchema, + ) + + # When we dump the model, the output_schema should be converted to a dict + dumped = request.model_dump() + assert isinstance(dumped["output_schema"], dict) + assert "properties" in dumped["output_schema"] + assert "name" in dumped["output_schema"]["properties"] + assert "age" in dumped["output_schema"]["properties"] + + # Invalid URL + with pytest.raises(ValidationError): + SmartScraperRequest( + website_url="invalid-url", user_prompt="Describe this page." + ) + + # Empty prompt + with pytest.raises(ValidationError): + SmartScraperRequest(website_url="https://example.com", user_prompt="") + + +def test_get_smartscraper_request_validation(): + # Valid UUID + request = GetSmartScraperRequest(request_id="123e4567-e89b-12d3-a456-426614174000") + assert request.request_id == "123e4567-e89b-12d3-a456-426614174000" + + # Invalid UUID + with pytest.raises(ValidationError): + GetSmartScraperRequest(request_id="invalid-uuid") + + +def test_feedback_request_validation(): + # Valid input + request = FeedbackRequest( + request_id="123e4567-e89b-12d3-a456-426614174000", + rating=5, + feedback_text="Great service!", + ) + assert request.request_id == "123e4567-e89b-12d3-a456-426614174000" + assert request.rating == 5 + assert request.feedback_text == "Great service!" + + # Invalid rating + with pytest.raises(ValidationError): + FeedbackRequest( + request_id="123e4567-e89b-12d3-a456-426614174000", + rating=6, + feedback_text="Great service!", + ) + + # Invalid UUID + with pytest.raises(ValidationError): + FeedbackRequest( + request_id="invalid-uuid", rating=5, feedback_text="Great service!" 
+ ) diff --git a/scrapegraph-py/tests/utils.py b/scrapegraph-py/tests/utils.py new file mode 100644 index 0000000..194ceef --- /dev/null +++ b/scrapegraph-py/tests/utils.py @@ -0,0 +1,6 @@ +from uuid import uuid4 + + +def generate_mock_api_key(): + """Generate a valid mock API key in the format 'sgai-{uuid}'""" + return f"sgai-{uuid4()}"