diff --git a/nodestream/pipeline/argument_resolvers/__init__.py b/nodestream/pipeline/argument_resolvers/__init__.py index 32b4c841d..ba63e0a39 100644 --- a/nodestream/pipeline/argument_resolvers/__init__.py +++ b/nodestream/pipeline/argument_resolvers/__init__.py @@ -1,4 +1,5 @@ from .argument_resolver import ARGUMENT_RESOLVER_REGISTRY, ArgumentResolver +from .aws_secret_resolver import AWSSecretResolver from .configuration_argument_resolver import ( ConfigurationArgumentResolver, get_config, @@ -12,6 +13,7 @@ "ArgumentResolver", "EnvironmentResolver", "IncludeFileResolver", + "AWSSecretResolver", "ConfigurationArgumentResolver", "get_config", "set_config", diff --git a/nodestream/pipeline/argument_resolvers/aws_secret_resolver.py b/nodestream/pipeline/argument_resolvers/aws_secret_resolver.py new file mode 100644 index 000000000..bd6efea19 --- /dev/null +++ b/nodestream/pipeline/argument_resolvers/aws_secret_resolver.py @@ -0,0 +1,411 @@ +from __future__ import annotations + +import json +import logging +import os +import threading +import time +from dataclasses import dataclass +from functools import wraps +from typing import Any, Callable, Dict, Optional, TypeVar, cast + +import boto3 +from botocore.exceptions import ClientError + +from nodestream.pipeline.argument_resolvers import ArgumentResolver + +# Type variables for decorators +T = TypeVar("T") +F = TypeVar("F", bound=Callable[..., Any]) + +# Configure structured logging +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class SecretResolverConfig: + """Configuration for the SecretResolver. 
+ + Attributes: + cache_ttl: Time-to-live for cache entries in seconds + max_retries: Maximum number of retries for AWS API calls + retry_delay: Delay between retries in seconds + region_name: AWS region name + log_level: Logging level + """ + + cache_ttl: int = 300 # 5 minutes + max_retries: int = 3 + retry_delay: float = 1.0 + # todo get region from environment variable or config + region_name: str = "us-west-2" + log_level: str = "INFO" + + +class SecretResolverError(Exception): + """Base exception for SecretResolver errors.""" + + pass + + +class SecretNotFoundError(SecretResolverError): + """Raised when a secret is not found in AWS Secrets Manager.""" + + pass + + +class SecretDecodeError(SecretResolverError): + """Raised when there is an error decoding a secret.""" + + pass + + +class SecretCacheError(SecretResolverError): + """Raised when there is an error with the secret cache.""" + + pass + + +def retry_on_error(max_retries: int = 3, delay: float = 1.0) -> Callable[[F], F]: + """Decorator to retry a function on failure. + + Args: + max_retries: Maximum number of retries + delay: Delay between retries in seconds + + Returns: + Decorated function that will retry on failure + + Example: + @retry_on_error(max_retries=3, delay=1.0) + def my_function(): + # Function that may fail + pass + """ + + def decorator(func: F) -> F: + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + last_exception = None + for attempt in range(max_retries): + try: + return func(*args, **kwargs) + except Exception as e: + last_exception = e + if attempt < max_retries - 1: + msg = ( + f"Attempt {attempt + 1} failed for {func.__name__}, " + f"retrying in {delay} seconds... Error: {str(e)}" + ) + logger.warning(msg) + time.sleep(delay) + raise last_exception or Exception("Unknown error occurred") + + return cast(F, wrapper) + + return decorator + + +class SecretCache: + """Thread-safe cache for secrets with TTL. 
+ + This class implements a thread-safe cache with time-to-live (TTL) for + storing and retrieving secrets. It uses a lock to ensure thread safety + and automatically removes expired entries. + + Attributes: + _ttl: Time-to-live for cache entries in seconds + _lock: Thread lock for thread safety + _cache: Dictionary storing cache entries with expiry timestamps + """ + + def __init__(self, ttl: int = 300) -> None: + """Initialize the secret cache. + + Args: + ttl: Time-to-live for cache entries in seconds + """ + self._ttl = ttl + self._lock = threading.Lock() + self._cache: Dict[str, tuple[Any, float]] = {} + + def get(self, key: str) -> Optional[Any]: + """Get a value from the cache if it exists and is not expired. + + Args: + key: Cache key + + Returns: + Cached value if it exists and is not expired, None otherwise + """ + with self._lock: + if key in self._cache: + value, expiry = self._cache[key] + if time.time() < expiry: + logger.debug(f"Cache HIT: {key}") + return value + logger.debug(f"Cache EXPIRED: {key}") + del self._cache[key] + logger.debug(f"Cache MISS: {key}") + return None + + def set(self, key: str, value: Any) -> None: + """Set a value in the cache with TTL. 
+ + Args: + key: Cache key + value: Value to cache + """ + with self._lock: + self._cache[key] = (value, time.time() + self._ttl) + msg = f"Cache SET: {key} (expires in {self._ttl} seconds)" + logger.debug(msg) + + +# Initialize caches lazily +_secret_cache: Optional[SecretCache] = None +_json_cache: Optional[SecretCache] = None + + +def _get_secret_cache() -> SecretCache: + """Get the secret cache, initializing it if necessary.""" + global _secret_cache + if _secret_cache is None: + _secret_cache = SecretCache() + return _secret_cache + + +def _get_json_cache() -> SecretCache: + """Get the JSON cache, initializing it if necessary.""" + global _json_cache + if _json_cache is None: + _json_cache = SecretCache() + return _json_cache + + +class AWSSecretResolver(ArgumentResolver, alias="aws-secret"): # type: ignore[call-arg] + """AWS Secrets Manager argument resolver for Nodestream with caching and retries. + + This resolver fetches secrets from AWS Secrets Manager and caches them for + performance. It supports both string secrets and JSON secrets with specific + key extraction. It implements a singleton pattern to ensure a single instance + is used throughout the application. + + Example usage in nodestream.yaml: + password: + resolver: aws-secret + variable: NEO4J_PASSWORD.password # For JSON secrets + # OR + variable: NEO4J_PASSWORD # For string secrets + + Attributes: + _instance: Singleton instance of the resolver + config: Configuration for the resolver + _session: AWS session + _client: AWS Secrets Manager client + """ + + _instance: Optional[AWSSecretResolver] = None + + def __new__( + cls, config: Optional[SecretResolverConfig] = None + ) -> AWSSecretResolver: + """Ensure singleton instance. 
+ + Args: + config: Optional configuration for the resolver + + Returns: + Singleton instance of SecretResolver + """ + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance.config = config or SecretResolverConfig() + cls._instance._session = boto3.session.Session() + cls._instance._client = cls._instance._session.client( + service_name="secretsmanager", + region_name=cls._instance.config.region_name, + ) + return cls._instance + + def __init__(self, config: Optional[SecretResolverConfig] = None) -> None: + """Initialize the SecretResolver. + + Args: + config: Optional configuration for the resolver + """ + # Skip initialization if instance already exists + if hasattr(self, "config"): + return + + self.config = config or SecretResolverConfig() + self._session = boto3.session.Session() + self._client = self._session.client( + service_name="secretsmanager", region_name=self.config.region_name + ) + # Initialize instance attributes to prevent AttributeError + self._secret_value: Optional[str] = None + self.default: Optional[str] = None + + @staticmethod + def _get_secret_name_from_env(env_var: str) -> Optional[str]: + """Get secret name from environment variable. + + Args: + env_var: Environment variable name + + Returns: + Secret name if environment variable exists and is not empty, None otherwise + """ + secret_name = os.environ.get(env_var) + if not secret_name: + logger.warning(f"Environment variable '{env_var}' is not set or is empty") + return None + return secret_name + + @retry_on_error() + def _get_secret_from_aws(self, secret_name: str) -> str: + """Fetch a secret from AWS Secrets Manager. 
+ + Args: + secret_name: Name of the secret to fetch + + Returns: + Secret value as string + + Raises: + SecretNotFoundError: If the secret is not found + SecretDecodeError: If the secret cannot be decoded + """ + try: + response = self._client.get_secret_value(SecretId=secret_name) + if "SecretString" in response: + return str(response["SecretString"]) + raise SecretDecodeError( + f"Secret '{secret_name}' is binary, which is not supported" + ) + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceNotFoundException": + raise SecretNotFoundError(f"Secret '{secret_name}' not found") from e + raise SecretResolverError( + f"Error fetching secret '{secret_name}': {e}" + ) from e + + def _resolve_string_secret(self, secret_name: str) -> Optional[str]: + """Resolve a string secret with caching. + + Args: + secret_name: Name of the secret to resolve + + Returns: + Secret value if successful, None otherwise + """ + logger.info(f"Resolving string secret '{secret_name}'") + + # Try cache first + cached_secret = _get_secret_cache().get(secret_name) + if cached_secret is not None: + return cached_secret # type: ignore[no-any-return] + + try: + # Fetch and cache + secret_value = self._get_secret_from_aws(secret_name) + _get_secret_cache().set(secret_name, secret_value) + logger.info(f"Cached string secret '{secret_name}'") + return secret_value + except SecretResolverError as e: + logger.error(f"Error resolving string secret '{secret_name}': {e}") + return None + + def _resolve_json_secret(self, secret_name: str, json_key: str) -> Optional[Any]: + """Resolve a JSON secret with caching. 
+ + Args: + secret_name: Name of the secret to resolve + json_key: Key to extract from the JSON secret + + Returns: + JSON value if successful, None otherwise + """ + logger.info(f"Resolving JSON secret '{secret_name}' with key '{json_key}'") + + cache_key = f"{secret_name}:{json_key}" + + # Try JSON cache first + cached_json = _get_json_cache().get(cache_key) + if cached_json is not None: + return cached_json + + # Get the secret string - this may raise SecretResolverError + secret_json_string = self._get_secret_from_aws(secret_name) + if not secret_json_string: + logger.error(f"Failed to get secret string for '{secret_name}'") + return None + + # Parse JSON - this may raise SecretDecodeError + try: + secret_data = json.loads(secret_json_string) + except json.JSONDecodeError as e: + logger.error(f"Secret '{secret_name}' is not valid JSON: {e}") + return None + + # Extract and cache the JSON value + if json_key not in secret_data: + logger.error(f"Key '{json_key}' not found in secret '{secret_name}'") + return None + + _get_json_cache().set(cache_key, secret_data[json_key]) + logger.info(f"Cached JSON key '{json_key}' from secret '{secret_name}'") + return secret_data[json_key] + + @classmethod + def resolve_argument(cls, variable_name: str) -> Optional[Any]: + """Resolve an argument by fetching it from AWS Secrets Manager with caching. + + This method is called by the nodestream plugin system to resolve arguments + that use the !aws-secret tag in the configuration. + + Supports two formats: + 1. 'ENV_VAR_NAME.json_key' - For JSON secrets, returns the specific JSON key value + 2. 
'ENV_VAR_NAME' - For string secrets, returns the entire secret value + + Args: + variable_name: The variable name in either format: + - 'ENV_VAR_NAME.json_key' for JSON secrets + - 'ENV_VAR_NAME' for string secrets + + Returns: + The resolved value from the secret, or None if resolution failed + + Example: + In nodestream.yaml: + password: !aws-secret NEO4J_PASSWORD.password + # OR + password: !aws-secret NEO4J_PASSWORD + """ + instance = cls() # Get singleton instance + try: + # Split the variable name into parts + parts = variable_name.split(".", 1) + env_var_part = parts[0] + json_key_part = parts[1] if len(parts) > 1 else None + + # Get secret name from environment variable + secret_name = instance._get_secret_name_from_env(env_var_part) + if not secret_name: + return None + + # Resolve based on type + if json_key_part is None: + return instance._resolve_string_secret(secret_name) + return instance._resolve_json_secret(secret_name, json_key_part) + + except Exception as e: + logger.error( + f"Unexpected error resolving '{variable_name}': {e}", exc_info=True + ) + return None + + def get_value(self) -> str: + """Return the secret value (or a default if not found).""" + return self._secret_value or self.default # type: ignore[no-any-return] diff --git a/nodestream/pipeline/value_providers/__init__.py b/nodestream/pipeline/value_providers/__init__.py index 282de631d..13c529107 100644 --- a/nodestream/pipeline/value_providers/__init__.py +++ b/nodestream/pipeline/value_providers/__init__.py @@ -7,6 +7,7 @@ from .split_value_provider import SplitValueProvider from .static_value_provider import StaticValueProvider from .string_format_value_provider import StringFormattingValueProvider +from .uuid_value_provider import UuidValueProvider from .value_provider import ( VALUE_PROVIDER_REGISTRY, StaticValueOrValueProvider, @@ -28,4 +29,5 @@ "StaticValueOrValueProvider", "RegexValueProvider", "CastValueProvider", + "UuidValueProvider", ) diff --git 
a/nodestream/pipeline/value_providers/uuid_value_provider.py b/nodestream/pipeline/value_providers/uuid_value_provider.py new file mode 100644 index 000000000..9d6f485db --- /dev/null +++ b/nodestream/pipeline/value_providers/uuid_value_provider.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import logging +import uuid +from typing import Any, Type + +from yaml import SafeDumper, SafeLoader + +from .context import ProviderContext +from .value_provider import ValueProvider + +# Configure structured logging +logger = logging.getLogger(__name__) + +# Default namespace for our application +DEFAULT_NAMESPACE = "nodestream" + + +class UuidValueProvider(ValueProvider): + """UUID generator value provider for Nodestream. + + This value provider generates UUIDs on demand and can be used with the !uuid + tag in Nodestream configuration files. + + Supports both simple string input and structured configuration: + + Simple format: + id: !uuid # Random UUID v4 + id: !uuid "finding" # Deterministic UUID v5 based on "finding" + + Structured format: + # Full configuration with both variable_name and namespace + id: !uuid + variable_name: "finding" + namespace: "my-custom-namespace" + + # Only variable_name (uses default namespace "nodestream") + id: !uuid + variable_name: "exposure_finding" + + # Only namespace (generates random UUID v4 with custom namespace) + id: !uuid + namespace: "my-random-namespace" + + # Empty configuration (generates random UUID v4 with default namespace) + id: !uuid + + When a variable_name is provided, it generates a deterministic UUID v5 + based on the namespace and variable_name. When no variable_name is provided, + it generates a random UUID v4. 
+ """ + + __slots__ = ("variable_name", "namespace") + + @classmethod + def install_yaml_tag(cls, loader: Type[SafeLoader]): + loader.add_constructor( + "!uuid", + lambda loader, node: cls.from_yaml_node(loader, node), + ) + + @classmethod + def from_yaml_node(cls, loader: SafeLoader, node): + """Create a UuidValueProvider from YAML node.""" + if node.id == "scalar": + # Simple string format: !uuid "finding" + expression = loader.construct_scalar(node) + return cls.from_string_expression(expression) + elif node.id == "mapping": + # Structured format: !uuid { variable_name: "finding", namespace: "ns" } + data = loader.construct_mapping(node) + return cls.from_structured_data(data) + else: + # Empty format: !uuid + return cls() + + @classmethod + def from_string_expression(cls, expression: str): + """Create from simple string expression.""" + return cls(variable_name=expression) + + @classmethod + def from_structured_data(cls, data: dict): + """Create from structured data dictionary.""" + variable_name = data.get("variable_name", "") + namespace = data.get("namespace", DEFAULT_NAMESPACE) + return cls(variable_name=variable_name, namespace=namespace) + + def __init__(self, variable_name: str = "", namespace: str = DEFAULT_NAMESPACE): + """Initialize the UUID value provider. + + Args: + variable_name: If provided, generates a deterministic UUID v5 + based on this name. If empty, generates a random UUID v4. + namespace: The namespace to use for deterministic UUID generation. + Defaults to "nodestream". + """ + self.variable_name = variable_name.strip() if variable_name else "" + self.namespace = namespace.strip() if namespace else DEFAULT_NAMESPACE + + def single_value(self, context: ProviderContext) -> Any: + """Generate a UUID value. 
+ + Args: + context: The provider context containing record data + + Returns: + A new UUID string + + Example: + In nodestream.yaml: + id: !uuid # Random UUID v4 + id: !uuid "finding" # Deterministic UUID v5 + """ + try: + if self.variable_name: + # Generate namespace UUID from the string + namespace_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, self.namespace) + # Generate deterministic UUID v5 based on variable_name + new_uuid = str(uuid.uuid5(namespace_uuid, self.variable_name)) + logger.debug( + f"Generated deterministic UUID v5: {new_uuid[:8]}... " + f"(namespace: '{self.namespace}', variable: '{self.variable_name}')" + ) + else: + # Generate random UUID v4 + new_uuid = str(uuid.uuid4()) + logger.debug(f"Generated random UUID v4: {new_uuid[:8]}...") + + return new_uuid + except Exception as e: + logger.error(f"Error generating UUID: {e}", exc_info=True) + # Return a fallback UUID in case of error + return str(uuid.uuid4()) + + def many_values(self, context: ProviderContext): + yield self.single_value(context) + + def __str__(self) -> str: + return f"UuidValueProvider: {{'variable_name': '{self.variable_name}', 'namespace': '{self.namespace}'}}" + + +SafeDumper.add_representer( + UuidValueProvider, + lambda dumper, uuid_provider: dumper.represent_scalar( + "!uuid", uuid_provider.variable_name + ), +) diff --git a/tests/unit/pipeline/argument_resolvers/test_aws_secret_resolver.py b/tests/unit/pipeline/argument_resolvers/test_aws_secret_resolver.py new file mode 100644 index 000000000..0d9de5b99 --- /dev/null +++ b/tests/unit/pipeline/argument_resolvers/test_aws_secret_resolver.py @@ -0,0 +1,151 @@ +import time + +import pytest +from hamcrest import assert_that, equal_to + +from nodestream.pipeline.argument_resolvers.aws_secret_resolver import ( + AWSSecretResolver, + SecretCache, +) + + +@pytest.fixture +def mock_boto3_client(mocker): + # Reset the singleton instance + AWSSecretResolver._instance = None + # Patch boto3 session and client + mock_client = mocker.Mock() + 
import time

import pytest
from hamcrest import assert_that, equal_to

from nodestream.pipeline.argument_resolvers.aws_secret_resolver import (
    AWSSecretResolver,
    SecretCache,
    _get_json_cache,
    _get_secret_cache,
)


@pytest.fixture
def mock_boto3_client(mocker):
    """Patched Secrets Manager client with fully reset resolver state.

    Resets the singleton AND both module-level caches so values cached by
    one test can never leak into another (previously only the singleton was
    reset, forcing individual tests to clear caches by hand).
    """
    AWSSecretResolver._instance = None
    _get_secret_cache()._cache.clear()
    _get_json_cache()._cache.clear()
    mock_client = mocker.Mock()
    mock_session = mocker.patch("boto3.session.Session")
    mock_session.return_value.client.return_value = mock_client
    return mock_client


def test_resolve_string_secret(monkeypatch, mock_boto3_client):
    """A plain string secret resolves to its SecretString value."""
    monkeypatch.setenv("FAKE_SECRET_ENV", "fake_secret_name")
    mock_boto3_client.get_secret_value.return_value = {"SecretString": "supersecret"}
    result = AWSSecretResolver.resolve_argument("FAKE_SECRET_ENV")
    assert_that(result, equal_to("supersecret"))


def test_resolve_json_secret(monkeypatch, mock_boto3_client):
    """A JSON secret resolves to the requested key's value."""
    monkeypatch.setenv("FAKE_JSON_SECRET_ENV", "fake_json_secret_name")
    mock_boto3_client.get_secret_value.return_value = {"SecretString": '{"k": 42}'}
    result = AWSSecretResolver.resolve_argument("FAKE_JSON_SECRET_ENV.k")
    assert_that(result, equal_to(42))


def test_resolve_json_secret_with_invalid_json_returns_none(
    monkeypatch, mock_boto3_client
):
    """Invalid JSON (trailing comma) yields None rather than raising."""
    monkeypatch.setenv("FAKE_INVALID_JSON_ENV", "fake_invalid_json_secret_name")
    mock_boto3_client.get_secret_value.return_value = {
        "SecretString": '{"k": 42,}'  # Invalid JSON - trailing comma
    }
    result = AWSSecretResolver.resolve_argument("FAKE_INVALID_JSON_ENV.k")
    assert_that(result, equal_to(None))


def test_resolve_json_secret_with_malformed_json_returns_none(
    monkeypatch, mock_boto3_client
):
    """Malformed JSON (unterminated object) yields None rather than raising."""
    monkeypatch.setenv("FAKE_MALFORMED_JSON_ENV", "fake_malformed_json_secret_name")
    mock_boto3_client.get_secret_value.return_value = {
        "SecretString": '{"k": 42'  # Malformed JSON - missing closing brace
    }
    result = AWSSecretResolver.resolve_argument("FAKE_MALFORMED_JSON_ENV.k")
    assert_that(result, equal_to(None))


def test_secret_cache_hit():
    """A freshly set entry is returned before its TTL expires."""
    cache = SecretCache(ttl=5)
    cache.set("foo", "bar")
    assert cache.get("foo") == "bar"


def test_secret_cache_expired(monkeypatch):
    """An entry past its TTL is treated as a miss and evicted."""
    cache = SecretCache(ttl=1)
    cache.set("foo", "bar")
    # Simulate time passing beyond TTL
    original_time = time.time
    monkeypatch.setattr(time, "time", lambda: original_time() + 2)
    assert cache.get("foo") is None
    # After expiry the key should have been removed from the store
    assert "foo" not in cache._cache


def test_secret_cache_miss():
    """An unknown key is a miss."""
    cache = SecretCache(ttl=5)
    assert cache.get("missing") is None


def test_resolve_argument_with_missing_env_var_returns_none(
    monkeypatch, mock_boto3_client
):
    """resolve_argument returns None when the environment variable is unset."""
    result = AWSSecretResolver.resolve_argument("MISSING_ENV_VAR")
    assert_that(result, equal_to(None))


def test_resolve_argument_with_empty_env_var_returns_none(
    monkeypatch, mock_boto3_client
):
    """resolve_argument returns None when the environment variable is empty."""
    monkeypatch.setenv("EMPTY_ENV_VAR", "")
    result = AWSSecretResolver.resolve_argument("EMPTY_ENV_VAR")
    assert_that(result, equal_to(None))


def test_resolve_json_argument_with_missing_env_var_returns_none(
    monkeypatch, mock_boto3_client
):
    """JSON-form resolution also returns None when the env var is unset."""
    result = AWSSecretResolver.resolve_argument("MISSING_ENV_VAR.key")
    assert_that(result, equal_to(None))


def test_resolve_argument_with_unexpected_exception_returns_none(
    monkeypatch, mock_boto3_client
):
    """Unexpected AWS client exceptions are swallowed and None is returned.

    No manual cache clearing needed here: the fixture resets both caches.
    """
    monkeypatch.setenv("FAKE_SECRET_ENV", "fake_secret_name")
    mock_boto3_client.get_secret_value.side_effect = Exception("Unexpected error")
    result = AWSSecretResolver.resolve_argument("FAKE_SECRET_ENV")
    assert_that(result, equal_to(None))
"""Tests for the structured UUID value provider."""

import uuid
from unittest.mock import patch

import pytest
import yaml
from hamcrest import assert_that, equal_to, is_not, matches_regexp

from nodestream.pipeline.value_providers.uuid_value_provider import UuidValueProvider

# Canonical UUID text shape, shared by every test instead of repeating the
# same regex literal five times.
UUID_RE = r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"


@pytest.fixture
def yaml_loader():
    """Install the !uuid tag on the SafeLoader."""
    UuidValueProvider.install_yaml_tag(yaml.SafeLoader)
    return yaml.SafeLoader


def test_simple_string_format_parsing(yaml_loader):
    """Bare !uuid and scalar !uuid "name" parse into the expected providers."""
    test_yaml = """
    test:
      random_uuid: !uuid
      deterministic_uuid: !uuid "finding"
    """

    data = yaml.safe_load(test_yaml)
    random_provider = data["test"]["random_uuid"]
    det_provider = data["test"]["deterministic_uuid"]

    # Bare tag: random UUID provider with default namespace
    assert_that(random_provider.variable_name, equal_to(""))
    assert_that(random_provider.namespace, equal_to("nodestream"))

    # Scalar tag: deterministic UUID provider with default namespace
    assert_that(det_provider.variable_name, equal_to("finding"))
    assert_that(det_provider.namespace, equal_to("nodestream"))


def test_structured_format_with_custom_namespace(yaml_loader):
    """Structured form carries both variable_name and a custom namespace."""
    test_yaml = """
    test:
      custom_uuid: !uuid
        variable_name: "finding"
        namespace: "my-custom-namespace"
    """

    data = yaml.safe_load(test_yaml)
    custom_provider = data["test"]["custom_uuid"]

    assert_that(custom_provider.variable_name, equal_to("finding"))
    assert_that(custom_provider.namespace, equal_to("my-custom-namespace"))


def test_structured_format_with_default_namespace(yaml_loader):
    """Structured form with only variable_name uses the default namespace."""
    test_yaml = """
    test:
      default_ns_uuid: !uuid
        variable_name: "finding"
    """

    data = yaml.safe_load(test_yaml)
    default_ns_provider = data["test"]["default_ns_uuid"]

    assert_that(default_ns_provider.variable_name, equal_to("finding"))
    assert_that(default_ns_provider.namespace, equal_to("nodestream"))


def test_structured_format_with_only_namespace(yaml_loader):
    """Structured form with only a namespace means random UUIDs."""
    test_yaml = """
    test:
      random_custom_ns: !uuid
        namespace: "my-random-namespace"
    """

    data = yaml.safe_load(test_yaml)
    random_custom_provider = data["test"]["random_custom_ns"]

    assert_that(random_custom_provider.variable_name, equal_to(""))
    assert_that(random_custom_provider.namespace, equal_to("my-random-namespace"))


def test_deterministic_uuid_consistency():
    """Same variable_name + namespace always yields the same UUID."""
    provider = UuidValueProvider(variable_name="finding", namespace="test-namespace")
    context = {"test": "data"}

    uuid1 = provider.single_value(context)
    uuid2 = provider.single_value(context)

    assert_that(uuid1, equal_to(uuid2))
    assert_that(uuid1, matches_regexp(UUID_RE))


def test_different_namespaces_produce_different_uuids():
    """Same variable_name under different namespaces yields different UUIDs."""
    provider1 = UuidValueProvider(variable_name="finding", namespace="namespace1")
    provider2 = UuidValueProvider(variable_name="finding", namespace="namespace2")
    context = {"test": "data"}

    uuid1 = provider1.single_value(context)
    uuid2 = provider2.single_value(context)

    assert_that(uuid1, is_not(equal_to(uuid2)))


def test_random_uuid_generation():
    """Without a variable_name, successive calls yield distinct valid UUIDs."""
    provider = UuidValueProvider()  # No variable_name = random UUID
    context = {"test": "data"}

    uuid1 = provider.single_value(context)
    uuid2 = provider.single_value(context)

    assert_that(uuid1, is_not(equal_to(uuid2)))
    assert_that(uuid1, matches_regexp(UUID_RE))
    assert_that(uuid2, matches_regexp(UUID_RE))


def test_random_uuid_with_custom_namespace():
    """A custom namespace without a variable_name still means random UUIDs."""
    provider = UuidValueProvider(namespace="custom-namespace")  # No variable_name
    context = {"test": "data"}

    uuid1 = provider.single_value(context)
    uuid2 = provider.single_value(context)

    assert_that(uuid1, is_not(equal_to(uuid2)))
    assert_that(uuid1, matches_regexp(UUID_RE))


def test_uuid_generation_with_patched_uuid():
    """Random generation calls uuid.uuid4 exactly once."""
    with patch("uuid.uuid4") as mock_uuid4:
        mock_uuid4.return_value = uuid.UUID("12345678-1234-5678-9abc-def123456789")

        provider = UuidValueProvider()  # Random UUID
        context = {"test": "data"}

        result = provider.single_value(context)

        assert_that(result, equal_to("12345678-1234-5678-9abc-def123456789"))
        mock_uuid4.assert_called_once()


def test_deterministic_uuid_with_patched_uuid5():
    """Deterministic generation calls uuid.uuid5 twice (namespace, then value)."""
    with patch("uuid.uuid5") as mock_uuid5:
        mock_uuid5.return_value = uuid.UUID("87654321-4321-8765-cba9-fed876543210")

        provider = UuidValueProvider(variable_name="test", namespace="test-namespace")
        context = {"test": "data"}

        result = provider.single_value(context)

        assert_that(result, equal_to("87654321-4321-8765-cba9-fed876543210"))
        # Once for the namespace UUID, once for the final UUID
        assert_that(mock_uuid5.call_count, equal_to(2))


def test_empty_variable_name_handling():
    """An explicitly empty variable_name behaves like the random case."""
    provider = UuidValueProvider(variable_name="", namespace="test-namespace")
    context = {"test": "data"}

    uuid1 = provider.single_value(context)
    uuid2 = provider.single_value(context)

    assert_that(uuid1, is_not(equal_to(uuid2)))


def test_whitespace_handling():
    """Leading/trailing whitespace is stripped from both inputs."""
    provider = UuidValueProvider(variable_name="  test  ", namespace="  namespace  ")

    assert_that(provider.variable_name, equal_to("test"))
    assert_that(provider.namespace, equal_to("namespace"))


def test_empty_namespace_uses_default():
    """An empty namespace falls back to the default."""
    provider = UuidValueProvider(variable_name="test", namespace="")

    assert_that(provider.namespace, equal_to("nodestream"))


def test_many_values_generator():
    """many_values yields exactly one value, equal to single_value's result."""
    provider = UuidValueProvider(variable_name="test")
    context = {"test": "data"}

    single_result = provider.single_value(context)
    many_results = list(provider.many_values(context))

    assert_that(len(many_results), equal_to(1))
    assert_that(many_results[0], equal_to(single_result))


def test_string_representation():
    """str() exposes both configured fields."""
    provider = UuidValueProvider(variable_name="test", namespace="test-namespace")

    expected = (
        "UuidValueProvider: {'variable_name': 'test', 'namespace': 'test-namespace'}"
    )
    assert_that(str(provider), equal_to(expected))


def test_empty_yaml_format(yaml_loader):
    """A bare !uuid with no value parses to the default provider."""
    test_yaml = """
    test:
      empty_uuid: !uuid
    """

    data = yaml.safe_load(test_yaml)
    empty_provider = data["test"]["empty_uuid"]

    assert_that(empty_provider.variable_name, equal_to(""))
    assert_that(empty_provider.namespace, equal_to("nodestream"))


def test_exception_handling_in_uuid_generation():
    """Generation errors fall back to a random UUID instead of raising."""
    # Random UUID: first uuid4 call raises, the fallback call succeeds.
    with patch(
        "uuid.uuid4",
        side_effect=[
            Exception("UUID generation failed"),
            uuid.UUID("12345678-1234-5678-9abc-def123456789"),
        ],
    ):
        provider1 = UuidValueProvider()  # Random UUID
        context = {"test": "data"}
        result1 = provider1.single_value(context)
        assert_that(result1, equal_to("12345678-1234-5678-9abc-def123456789"))

    # Deterministic UUID: uuid5 raises; fallback uses unpatched uuid4.
    with patch("uuid.uuid5", side_effect=Exception("UUID generation failed")):
        provider2 = UuidValueProvider(variable_name="test", namespace="test-namespace")
        context = {"test": "data"}
        result2 = provider2.single_value(context)
        assert_that(result2, matches_regexp(UUID_RE))