diff --git a/.github/codeql/codeql-config.yml b/.github/codeql/codeql-config.yml new file mode 100644 index 0000000..c00638a --- /dev/null +++ b/.github/codeql/codeql-config.yml @@ -0,0 +1,8 @@ +name: "CodeQL Config" + +# Exclude example files from CodeQL analysis +# Examples contain fake credentials for demonstration purposes +paths-ignore: + - "examples/**" + - "**/test_*.py" + - "tests/**" diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 4a7c2ea..670ccc4 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -29,6 +29,8 @@ jobs: languages: python # Queries: security-extended includes more security checks queries: security-extended + # Use custom config to exclude example files + config-file: ./.github/codeql/codeql-config.yml - name: Perform CodeQL Analysis uses: github/codeql-action/analyze@v3 diff --git a/README.md b/README.md index a11603f..765cea5 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ - ✅ **7x faster than pydantic-settings** - High performance built on msgspec - ✅ **Drop-in API compatibility** - Familiar interface, easy migration from pydantic-settings - ✅ **Type-safe** - Full type hints and validation +- ✅ **17+ built-in validators** - Email, URLs, numeric constraints, payment cards, paths, and more - ✅ **.env support** - Fast built-in .env parser (no dependencies) - ✅ **Nested settings** - Support for complex configuration structures - ✅ **Zero dependencies** - Only msgspec required @@ -99,6 +100,100 @@ settings = AppSettings() ## Advanced Usage +### Field Validators + +msgspec-ext provides 17+ built-in validator types for common use cases: + +#### Numeric Constraints + +```python +from msgspec_ext import BaseSettings, PositiveInt, NonNegativeInt + +class ServerSettings(BaseSettings): + port: PositiveInt # Must be > 0 + max_connections: PositiveInt + retry_count: NonNegativeInt # Can be 0 +``` + +**Available numeric types:** +- `PositiveInt`, `NegativeInt`, `NonNegativeInt`, `NonPositiveInt` +- `PositiveFloat`, `NegativeFloat`, `NonNegativeFloat`, `NonPositiveFloat` + +#### String Validators + +```python +from msgspec_ext import BaseSettings, EmailStr, HttpUrl, SecretStr + +class AppSettings(BaseSettings): + admin_email: EmailStr # RFC 5321 validation + api_url: HttpUrl # HTTP/HTTPS only + api_key: SecretStr # Masked in logs/output +``` + +**Available string types:** +- `EmailStr` - Email validation (RFC 5321) +- `HttpUrl` - HTTP/HTTPS URLs only +- `AnyUrl` - Any valid URL scheme +- `SecretStr` - Masks sensitive data in output + +#### Database & Cache Validators + +```python +from msgspec_ext import BaseSettings, PostgresDsn, RedisDsn + +class ConnectionSettings(BaseSettings): + database_url: PostgresDsn # postgresql://user:pass@host/db + cache_url: RedisDsn # redis://localhost:6379 +``` + +#### Payment Card Validation + +```python +from msgspec_ext import BaseSettings, PaymentCardNumber + +class PaymentSettings(BaseSettings): + card: PaymentCardNumber # Luhn algorithm + masking +``` + +**Features:** +- Validates using Luhn algorithm +- Automatically strips spaces/dashes +- Masks card number in repr (shows last 4 digits only) + +#### Path Validators + +```python +from msgspec_ext import BaseSettings, FilePath, DirectoryPath + +class PathSettings(BaseSettings): + config_file: FilePath # Must exist and be a file + data_dir: DirectoryPath # Must exist and be a directory +``` + +**Complete validator list:** + +| Validator | Description | +|-----------|-------------| +| `PositiveInt` | Integer > 0 | +| `NegativeInt` | Integer < 0 | +| `NonNegativeInt` | Integer ≥ 0 | +| `NonPositiveInt` | Integer ≤ 0 | +| `PositiveFloat` | Float > 0.0 | +| `NegativeFloat` | Float < 0.0 | +| `NonNegativeFloat` | Float ≥ 0.0 | +| `NonPositiveFloat` | Float ≤ 0.0 | +| `EmailStr` | Email address (RFC 5321) | +| `HttpUrl` | HTTP/HTTPS URL | +| `AnyUrl` | Any valid URL | +| `SecretStr` | Masked sensitive data | +| `PostgresDsn` | PostgreSQL connection string | +| `RedisDsn` | Redis connection string | +| `PaymentCardNumber` | Credit card with Luhn validation | +| `FilePath` | Existing file path | +| `DirectoryPath` | Existing directory path | + +See `examples/06_validators.py` for complete examples. + ### Nested Configuration ```python diff --git a/examples/06_validators.py b/examples/06_validators.py new file mode 100644 index 0000000..3462fb2 --- /dev/null +++ b/examples/06_validators.py @@ -0,0 +1,367 @@ +"""Example demonstrating custom validators and constrained types. + +This example shows how to use msgspec-ext's built-in validators: +- EmailStr for email validation +- HttpUrl for HTTP/HTTPS URL validation +- PositiveInt, NegativeInt, NonNegativeInt for numeric constraints +- PositiveFloat, NegativeFloat for float constraints +""" + +import os +import tempfile + +from msgspec_ext import ( + AnyUrl, + BaseSettings, + DirectoryPath, + EmailStr, + FilePath, + HttpUrl, + NegativeInt, + NonNegativeInt, + PaymentCardNumber, + PositiveFloat, + PositiveInt, + PostgresDsn, + RedisDsn, + SecretStr, + SettingsConfigDict, +) + + +# Example 1: Email validation +class EmailSettings(BaseSettings): + """Settings with email validation.""" + + model_config = SettingsConfigDict(env_prefix="EMAIL_") + + admin_email: EmailStr + support_email: EmailStr + notifications_email: EmailStr = EmailStr("noreply@example.com") + + +# Example 2: URL validation +class APISettings(BaseSettings): + """Settings with URL validation.""" + + model_config = SettingsConfigDict(env_prefix="API_") + + base_url: HttpUrl # HTTP/HTTPS only + webhook_url: HttpUrl + docs_url: AnyUrl # Any valid URL scheme + + +# Example 3: Numeric constraints +class DatabaseSettings(BaseSettings): + """Settings with numeric constraints.""" + + model_config = SettingsConfigDict(env_prefix="DB_") + + port: PositiveInt # > 0 + max_connections: PositiveInt + min_connections: NonNegativeInt # >= 0 + timeout: PositiveFloat # > 0.0 + + +# Example 4: Secret string (passwords, API keys, tokens) +class SecretSettings(BaseSettings): + """Settings with secret strings (masked in logs/output).""" + + model_config = SettingsConfigDict(env_prefix="SECRET_") + + api_key: SecretStr + database_password: SecretStr + jwt_secret: SecretStr + + +# Example 5: Database and cache DSN validation +class ConnectionSettings(BaseSettings): + """Settings with DSN validation.""" + + model_config = SettingsConfigDict(env_prefix="CONN_") + + postgres_url: PostgresDsn + redis_url: RedisDsn + + +# Example 6: Payment card validation +class PaymentSettings(BaseSettings): + """Settings with payment card validation.""" + + model_config = SettingsConfigDict(env_prefix="PAYMENT_") + + card_number: PaymentCardNumber + + +# Example 7: File and directory path validation +class PathSettings(BaseSettings): + """Settings with path validation.""" + + model_config = SettingsConfigDict(env_prefix="PATH_") + + config_file: FilePath + data_directory: DirectoryPath + + +# Example 8: Combined validators +class AppSettings(BaseSettings): + """Real-world app settings with multiple validators.""" + + # Email validation + admin_email: EmailStr + + # URL validation + api_url: HttpUrl + frontend_url: HttpUrl + + # Secret strings (masked) + api_key: SecretStr + db_password: SecretStr + + # Positive integers + port: PositiveInt + max_workers: PositiveInt + + # Non-negative integers (can be 0) + retry_count: NonNegativeInt + + # Positive floats + timeout: PositiveFloat + rate_limit: PositiveFloat + + # Defaults + debug: bool = False + + +def main(): # noqa: PLR0915 + print("=" * 60) + print("msgspec-ext Validators Demo") + print("=" * 60) + + # Example 1: Email validation + print("\n1. Email Validation") + print("-" * 60) + + os.environ.update( + { + "EMAIL_ADMIN_EMAIL": "admin@example.com", + "EMAIL_SUPPORT_EMAIL": "support@company.org", + } + ) + + email_settings = EmailSettings() + print(f"Admin Email: {email_settings.admin_email}") + print(f"Support Email: {email_settings.support_email}") + print(f"Notifications: {email_settings.notifications_email}") + + # Try invalid email (will raise ValueError) + try: + EmailStr("not-an-email") + except ValueError as e: + print(f"✓ Email validation works: {e}") + + # Example 2: URL validation + print("\n2. URL Validation") + print("-" * 60) + + os.environ.update( + { + "API_BASE_URL": "https://api.example.com", + "API_WEBHOOK_URL": "https://webhook.example.com/events", + "API_DOCS_URL": "https://docs.example.com", + } + ) + + api_settings = APISettings() + print(f"Base URL: {api_settings.base_url}") + print(f"Webhook URL: {api_settings.webhook_url}") + print(f"Docs URL: {api_settings.docs_url}") + + # Try invalid URL (will raise ValueError) + try: + HttpUrl("not a url") + except ValueError as e: + print(f"✓ URL validation works: {e}") + + # Try non-HTTP scheme (will raise ValueError) + try: + HttpUrl("ftp://example.com") + except ValueError as e: + print(f"✓ HTTP scheme validation works: {e}") + + # Example 3: Numeric constraints + print("\n3. Numeric Constraints") + print("-" * 60) + + os.environ.update( + { + "DB_PORT": "5432", + "DB_MAX_CONNECTIONS": "100", + "DB_MIN_CONNECTIONS": "0", + "DB_TIMEOUT": "30.5", + } + ) + + db_settings = DatabaseSettings() + print(f"Port: {db_settings.port} (PositiveInt)") + print(f"Max Connections: {db_settings.max_connections} (PositiveInt)") + print(f"Min Connections: {db_settings.min_connections} (NonNegativeInt)") + print(f"Timeout: {db_settings.timeout}s (PositiveFloat)") + + # Example 4: Secret strings + print("\n4. Secret Strings (Masked in Output)") + print("-" * 60) + + # codeql[py/clear-text-logging-sensitive-data] - example code with fake credentials + os.environ.update( + { + "SECRET_API_KEY": "sk_live_1234567890abcdef", # example fake key + "SECRET_DATABASE_PASSWORD": "super-secret-password-123", # example fake password + "SECRET_JWT_SECRET": "jwt-signing-key-xyz", # example fake secret + } + ) + + secret_settings = SecretSettings() + print(f"API Key: {secret_settings.api_key}") # Masked + print(f"DB Password: {secret_settings.database_password}") # Masked + print(f"JWT Secret: {secret_settings.jwt_secret}") # Masked + print(f"Actual API Key: {secret_settings.api_key.get_secret_value()}") # Unmasked + + # Example 5: Database and Cache DSN validation + print("\n5. Database & Cache DSN Validation") + print("-" * 60) + + os.environ.update( + { + "CONN_POSTGRES_URL": "postgresql://user:pass@localhost:5432/mydb", + "CONN_REDIS_URL": "redis://localhost:6379/0", + } + ) + + conn_settings = ConnectionSettings() + print(f"PostgreSQL: {conn_settings.postgres_url}") + print(f"Redis: {conn_settings.redis_url}") + + # Try invalid DSN + try: + PostgresDsn("mysql://localhost/db") + except ValueError as e: + print(f"✓ PostgreSQL DSN validation works: {e}") + + # Example 6: Payment card validation + print("\n6. Payment Card Validation (Luhn Algorithm)") + print("-" * 60) + + os.environ.update( + { + "PAYMENT_CARD_NUMBER": "4532-0151-1283-0366", # Valid Visa test card + } + ) + + payment_settings = PaymentSettings() + print(f"Card Number: {payment_settings.card_number}") # Shows digits + print(f"Card Repr: {payment_settings.card_number!r}") # Masked! + + # Try invalid card + try: + PaymentCardNumber("1234567890123456") + except ValueError as e: + print(f"✓ Card validation works: {e}") + + # Example 7: File and directory path validation + print("\n7. File & Directory Path Validation") + print("-" * 60) + + # Create temporary file and use current directory + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".conf") as f: + f.write("config=value") + temp_config_file = f.name + + os.environ.update( + { + "PATH_CONFIG_FILE": temp_config_file, + "PATH_DATA_DIRECTORY": ".", # Current directory + } + ) + + path_settings = PathSettings() + print(f"Config File: {path_settings.config_file}") + print(f"Data Directory: {path_settings.data_directory}") + + # Cleanup + os.unlink(temp_config_file) + + # Try nonexistent file + try: + FilePath("/nonexistent/file.txt") + except ValueError as e: + print(f"✓ Path validation works: {e}") + + # Example 8: Real-World Combined validators + print("\n8. Real-World App Settings") + print("-" * 60) + + # Set environment variables + # codeql[py/clear-text-logging-sensitive-data] - example code with fake credentials + os.environ.update( + { + "ADMIN_EMAIL": "admin@myapp.com", + "API_URL": "https://api.myapp.com", + "FRONTEND_URL": "https://myapp.com", + "API_KEY": "sk_prod_secret_key_123", # example fake key + "DB_PASSWORD": "postgres_password_456", # example fake password + "PORT": "8000", + "MAX_WORKERS": "4", + "RETRY_COUNT": "3", + "TIMEOUT": "30.0", + "RATE_LIMIT": "100.0", + } + ) + + app_settings = AppSettings() + print(f"Admin: {app_settings.admin_email}") + print(f"API: {app_settings.api_url}") + print(f"Frontend: {app_settings.frontend_url}") + print(f"API Key: {app_settings.api_key}") # Masked! + print(f"DB Password: {app_settings.db_password}") # Masked! + print(f"Port: {app_settings.port}") + print(f"Workers: {app_settings.max_workers}") + print(f"Retries: {app_settings.retry_count}") + print(f"Timeout: {app_settings.timeout}s") + print(f"Rate Limit: {app_settings.rate_limit}/s") + print(f"Debug: {app_settings.debug}") + + # Example 9: Validation errors + print("\n9. Validation Error Examples") + print("-" * 60) + + # Negative port (invalid) + os.environ["PORT"] = "-1" + try: + AppSettings() + except ValueError as e: + print(f"✓ Negative port rejected: {e}") + + # Zero for PositiveInt (invalid) + os.environ["PORT"] = "0" + try: + AppSettings() + except ValueError as e: + print(f"✓ Zero rejected for PositiveInt: {e}") + + # Zero for NonNegativeInt (valid!) + os.environ["PORT"] = "8000" # Valid again + os.environ["RETRY_COUNT"] = "0" + try: + settings = AppSettings() + print(f"✓ Zero allowed for NonNegativeInt: {settings.retry_count}") + except ValueError as e: + print(f"Unexpected error: {e}") + + print("\n" + "=" * 60) + print("All validator examples completed successfully!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/src/msgspec_ext/__init__.py b/src/msgspec_ext/__init__.py index 090fc35..37a3dc0 100644 --- a/src/msgspec_ext/__init__.py +++ b/src/msgspec_ext/__init__.py @@ -1,6 +1,42 @@ from .settings import BaseSettings, SettingsConfigDict +from .types import ( + AnyUrl, + DirectoryPath, + EmailStr, + FilePath, + HttpUrl, + NegativeFloat, + NegativeInt, + NonNegativeFloat, + NonNegativeInt, + NonPositiveFloat, + NonPositiveInt, + PaymentCardNumber, + PositiveFloat, + PositiveInt, + PostgresDsn, + RedisDsn, + SecretStr, +) __all__ = [ + "AnyUrl", "BaseSettings", + "DirectoryPath", + "EmailStr", + "FilePath", + "HttpUrl", + "NegativeFloat", + "NegativeInt", + "NonNegativeFloat", + "NonNegativeInt", + "NonPositiveFloat", + "NonPositiveInt", + "PaymentCardNumber", + "PositiveFloat", + "PositiveInt", + "PostgresDsn", + "RedisDsn", + "SecretStr", "SettingsConfigDict", ] diff --git a/src/msgspec_ext/settings.py b/src/msgspec_ext/settings.py index ef7766a..639a894 100644 --- a/src/msgspec_ext/settings.py +++ b/src/msgspec_ext/settings.py @@ -1,15 +1,61 @@ """Optimized settings management using msgspec.Struct and bulk JSON decoding.""" import os -from typing import Any, ClassVar, Union, get_args, get_origin +from typing import Annotated, Any, ClassVar, Union, get_args, get_origin import msgspec from msgspec_ext.fast_dotenv import load_dotenv +from msgspec_ext.types import ( + AnyUrl, + DirectoryPath, + EmailStr, + FilePath, + HttpUrl, + PaymentCardNumber, + PostgresDsn, + RedisDsn, + SecretStr, +) __all__ = ["BaseSettings", "SettingsConfigDict"] +def _dec_hook(typ: type, obj: Any) -> Any: + """Decoding hook for custom types. + + Handles conversion from JSON-decoded values to custom types like EmailStr, HttpUrl, etc. + + Args: + typ: The target type to decode to + obj: The JSON-decoded object + + Returns: + Converted object of type typ + + Raises: + NotImplementedError: If type is not supported + """ + # Handle our custom string types + custom_types = ( + EmailStr, + HttpUrl, + AnyUrl, + SecretStr, + PostgresDsn, + RedisDsn, + PaymentCardNumber, + FilePath, + DirectoryPath, + ) + if typ in custom_types: + if isinstance(obj, str): + return typ(obj) + + # If we don't handle it, let msgspec raise an error + raise NotImplementedError(f"Type {typ} unsupported in dec_hook") + + class SettingsConfigDict(msgspec.Struct): """Configuration options for BaseSettings.""" @@ -195,7 +241,7 @@ def _decode_from_dict(cls, struct_cls, values: dict[str, Any]): encoder_decoder = cls._encoder_cache.get(struct_cls) if encoder_decoder is None: encoder = msgspec.json.Encoder() - decoder = msgspec.json.Decoder(type=struct_cls) + decoder = msgspec.json.Decoder(type=struct_cls, dec_hook=_dec_hook) encoder_decoder = (encoder, decoder) cls._encoder_cache[struct_cls] = encoder_decoder cls._decoder_cache[struct_cls] = encoder_decoder @@ -306,7 +352,7 @@ def _get_env_name(cls, field_name: str) -> str: return env_name @classmethod - def _preprocess_env_value(cls, env_value: str, field_type: type) -> Any: # noqa: C901 + def _preprocess_env_value(cls, env_value: str, field_type: type) -> Any: # noqa: C901, PLR0912 """Convert environment variable string to JSON-compatible type. Ultra-optimized to minimize type introspection overhead with caching. @@ -344,8 +390,21 @@ def _preprocess_env_value(cls, env_value: str, field_type: type) -> Any: # noqa except ValueError as e: raise ValueError(f"Cannot convert '{env_value}' to float") from e - # Only use typing introspection for complex types (Union, Optional, etc.) + # Only use typing introspection for complex types (Union, Optional, Annotated, etc.) origin = get_origin(field_type) + + # Handle Annotated types (e.g., Annotated[int, Meta(...)]) + # For Annotated, get_origin returns typing.Annotated and get_args()[0] is the base type + if origin is not None and ( + origin is Annotated or str(origin) == "typing.Annotated" + ): + args = get_args(field_type) + if args: + base_type = args[0] + # Cache and recursively process with base type + cls._type_cache[field_type] = base_type + return cls._preprocess_env_value(env_value, base_type) + if origin is Union: args = get_args(field_type) non_none = [a for a in args if a is not type(None)] diff --git a/src/msgspec_ext/types.py b/src/msgspec_ext/types.py new file mode 100644 index 0000000..67e633d --- /dev/null +++ b/src/msgspec_ext/types.py @@ -0,0 +1,508 @@ +"""Custom types and validators for msgspec-ext. + +Provides Pydantic-like type aliases and validation types built on msgspec.Meta. + +Example: + from msgspec_ext import BaseSettings + from msgspec_ext.types import EmailStr, HttpUrl, PositiveInt + + class AppSettings(BaseSettings): + email: EmailStr + api_url: HttpUrl + max_connections: PositiveInt +""" + +import os +import re +from typing import Annotated + +import msgspec + +__all__ = [ + "AnyUrl", + "DirectoryPath", + "EmailStr", + "FilePath", + "HttpUrl", + "NegativeFloat", + "NegativeInt", + "NonNegativeFloat", + "NonNegativeInt", + "NonPositiveFloat", + "NonPositiveInt", + "PaymentCardNumber", + "PositiveFloat", + "PositiveInt", + "PostgresDsn", + "RedisDsn", + "SecretStr", +] + +# ============================================================================== +# Numeric Constraint Types +# ============================================================================== + +# Integer types +PositiveInt = Annotated[int, msgspec.Meta(gt=0, description="Integer greater than 0")] +NegativeInt = Annotated[int, msgspec.Meta(lt=0, description="Integer less than 0")] +NonNegativeInt = Annotated[int, msgspec.Meta(ge=0, description="Integer >= 0")] +NonPositiveInt = Annotated[int, msgspec.Meta(le=0, description="Integer <= 0")] + +# Float types +PositiveFloat = Annotated[ + float, msgspec.Meta(gt=0.0, description="Float greater than 0.0") +] +NegativeFloat = Annotated[ + float, msgspec.Meta(lt=0.0, description="Float less than 0.0") +] +NonNegativeFloat = Annotated[float, msgspec.Meta(ge=0.0, description="Float >= 0.0")] +NonPositiveFloat = Annotated[float, msgspec.Meta(le=0.0, description="Float <= 0.0")] + + +# ============================================================================== +# String Validation Types with Custom Logic +# ============================================================================== + +# Email validation constants +_EMAIL_MIN_LENGTH = 3 +_EMAIL_MAX_LENGTH = 320 # RFC 5321 + +# URL validation constants +_URL_MAX_LENGTH = 2083 # IE limit, de facto standard + +# Payment card validation constants +_CARD_MIN_LENGTH = 13 +_CARD_MAX_LENGTH = 19 +_CARD_MASK_LAST_DIGITS = 4 +_LUHN_DOUBLE_THRESHOLD = 9 + +# Email regex pattern (simplified but covers most common cases) +# More strict than basic patterns, requires @ and domain with TLD +_EMAIL_PATTERN = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$" + +# URL regex patterns +_HTTP_URL_PATTERN = r"^https?://[^\s/$.?#].[^\s]*$" +_ANY_URL_PATTERN = r"^[a-zA-Z][a-zA-Z0-9+.-]*:.+$" + + +class _EmailStr(str): + """Email string type with validation. + + Validates email format using regex pattern. + Compatible with msgspec encoding/decoding. + """ + + __slots__ = () + + def __new__(cls, value: str) -> "_EmailStr": + """Create and validate email string. + + Args: + value: Email address string + + Returns: + Validated email string + + Raises: + ValueError: If email format is invalid + """ + if not isinstance(value, str): + raise TypeError(f"Expected str, got {type(value).__name__}") + + # Strip whitespace + value = value.strip() + + # Validate length + if not value or len(value) < _EMAIL_MIN_LENGTH: + raise ValueError(f"Email must be at least {_EMAIL_MIN_LENGTH} characters") + if len(value) > _EMAIL_MAX_LENGTH: + raise ValueError(f"Email must be at most {_EMAIL_MAX_LENGTH} characters") + + # Validate format + if not re.match(_EMAIL_PATTERN, value): + raise ValueError(f"Invalid email format: {value!r}") + + return str.__new__(cls, value) + + def __repr__(self) -> str: + return f"EmailStr({str.__repr__(self)})" + + +class _HttpUrl(str): + """HTTP/HTTPS URL string type with validation. + + Validates URL format and scheme (http or https only). + Compatible with msgspec encoding/decoding. + """ + + __slots__ = () + + def __new__(cls, value: str) -> "_HttpUrl": + """Create and validate HTTP URL string. + + Args: + value: HTTP/HTTPS URL string + + Returns: + Validated URL string + + Raises: + ValueError: If URL format is invalid or scheme is not http/https + """ + if not isinstance(value, str): + raise TypeError(f"Expected str, got {type(value).__name__}") + + # Strip whitespace + value = value.strip() + + # Validate length + if not value: + raise ValueError("URL cannot be empty") + if len(value) > _URL_MAX_LENGTH: + raise ValueError(f"URL must be at most {_URL_MAX_LENGTH} characters") + + # Validate format + if not re.match(_HTTP_URL_PATTERN, value, re.IGNORECASE): + raise ValueError(f"Invalid HTTP URL format: {value!r}") + + # Ensure http/https scheme + lower_value = value.lower() + if not ( + lower_value.startswith("http://") or lower_value.startswith("https://") + ): + raise ValueError(f"URL must use http or https scheme: {value!r}") + + return str.__new__(cls, value) + + def __repr__(self) -> str: + return f"HttpUrl({str.__repr__(self)})" + + +class _AnyUrl(str): + """URL string type with validation for any scheme. + + Validates URL format for any valid scheme (http, https, ftp, ws, etc). + Compatible with msgspec encoding/decoding. + """ + + __slots__ = () + + def __new__(cls, value: str) -> "_AnyUrl": + """Create and validate URL string. + + Args: + value: URL string with any scheme + + Returns: + Validated URL string + + Raises: + ValueError: If URL format is invalid + """ + if not isinstance(value, str): + raise TypeError(f"Expected str, got {type(value).__name__}") + + # Strip whitespace + value = value.strip() + + # Validate length + if not value: + raise ValueError("URL cannot be empty") + + # Validate format + if not re.match(_ANY_URL_PATTERN, value, re.IGNORECASE): + raise ValueError(f"Invalid URL format: {value!r}") + + return str.__new__(cls, value) + + def __repr__(self) -> str: + return f"AnyUrl({str.__repr__(self)})" + + +class _SecretStr(str): + """Secret string type that masks the value in string representation. + + Useful for passwords, API keys, tokens, and other sensitive data. + The actual value is accessible but hidden in logs and reprs. + Compatible with msgspec encoding/decoding. + """ + + __slots__ = () + + def __new__(cls, value: str) -> "_SecretStr": + """Create secret string. + + Args: + value: The secret string value + + Returns: + Secret string instance + + Raises: + TypeError: If value is not a string + """ + if not isinstance(value, str): + raise TypeError(f"Expected str, got {type(value).__name__}") + + return str.__new__(cls, value) + + def __repr__(self) -> str: + """Return masked representation.""" + return "SecretStr('**********')" + + def __str__(self) -> str: + """Return masked string representation.""" + return "**********" + + def get_secret_value(self) -> str: + """Get the actual secret value. + + Returns: + The unmasked secret string + """ + return str.__str__(self) + + +class _PostgresDsn(str): + """PostgreSQL DSN (Data Source Name) validation. + + Validates PostgreSQL connection strings. + Format: postgresql://user:password@host:port/database + """ + + __slots__ = () + + def __new__(cls, value: str) -> "_PostgresDsn": + """Create and validate PostgreSQL DSN. + + Args: + value: PostgreSQL connection string + + Returns: + Validated DSN string + + Raises: + ValueError: If DSN format is invalid + """ + if not isinstance(value, str): + raise TypeError(f"Expected str, got {type(value).__name__}") + + value = value.strip() + + # Check scheme + if not value.lower().startswith(("postgresql://", "postgres://")): + raise ValueError( + "PostgreSQL DSN must start with 'postgresql://' or 'postgres://'" + ) + + # Basic validation: must have a host/database part after scheme + # Format can be: postgresql://host/db or postgresql://user:pass@host/db + scheme_end = value.find("://") + 3 + remainder = value[scheme_end:] + if not remainder or "/" not in remainder: + raise ValueError("Invalid PostgreSQL DSN format") + + return str.__new__(cls, value) + + def __repr__(self) -> str: + return f"PostgresDsn({str.__repr__(self)})" + + +class _RedisDsn(str): + """Redis DSN (Data Source Name) validation. + + Validates Redis connection strings. + Format: redis://[user:password@]host:port[/database] + """ + + __slots__ = () + + def __new__(cls, value: str) -> "_RedisDsn": + """Create and validate Redis DSN. + + Args: + value: Redis connection string + + Returns: + Validated DSN string + + Raises: + ValueError: If DSN format is invalid + """ + if not isinstance(value, str): + raise TypeError(f"Expected str, got {type(value).__name__}") + + value = value.strip() + + # Check scheme + if not value.lower().startswith(("redis://", "rediss://")): + raise ValueError("Redis DSN must start with 'redis://' or 'rediss://'") + + return str.__new__(cls, value) + + def __repr__(self) -> str: + return f"RedisDsn({str.__repr__(self)})" + + +class _PaymentCardNumber(str): + """Payment card number validation using Luhn algorithm. + + Validates credit/debit card numbers. + Supports major card types: Visa, Mastercard, Amex, Discover, etc. + """ + + __slots__ = () + + def __new__(cls, value: str) -> "_PaymentCardNumber": + """Create and validate payment card number. + + Args: + value: Card number (with or without spaces/dashes) + + Returns: + Validated card number + + Raises: + ValueError: If card number is invalid + """ + if not isinstance(value, str): + raise TypeError(f"Expected str, got {type(value).__name__}") + + # Remove spaces and dashes + digits = value.replace(" ", "").replace("-", "") + + # Check all digits + if not digits.isdigit(): + raise ValueError("Card number must contain only digits") + + # Check length (13-19 digits for most cards) + if not _CARD_MIN_LENGTH <= len(digits) <= _CARD_MAX_LENGTH: + raise ValueError( + f"Card number must be {_CARD_MIN_LENGTH}-{_CARD_MAX_LENGTH} digits" + ) + + # Luhn algorithm validation + if not cls._luhn_check(digits): + raise ValueError("Invalid card number (failed Luhn check)") + + return str.__new__(cls, digits) + + @staticmethod + def _luhn_check(card_number: str) -> bool: + """Validate card number using Luhn algorithm. + + Args: + card_number: Card number string (digits only) + + Returns: + True if valid, False otherwise + """ + total = 0 + reverse_digits = card_number[::-1] + + for i, digit in enumerate(reverse_digits): + n = int(digit) + if i % 2 == 1: # Every second digit from the right + n *= 2 + if n > _LUHN_DOUBLE_THRESHOLD: + n -= _LUHN_DOUBLE_THRESHOLD + total += n + + return total % 10 == 0 + + def __repr__(self) -> str: + # Mask all but last 4 digits + if len(self) >= _CARD_MASK_LAST_DIGITS: + masked = ( + "*" * (len(self) - _CARD_MASK_LAST_DIGITS) + + str.__str__(self)[-_CARD_MASK_LAST_DIGITS:] + ) + else: + masked = "*" * len(self) + return f"PaymentCardNumber('{masked}')" + + +class _FilePath(str): + """File path validation - must exist and be a file. + + Validates that the path exists and points to a file (not directory). + """ + + __slots__ = () + + def __new__(cls, value: str) -> "_FilePath": + """Create and validate file path. + + Args: + value: Path to file + + Returns: + Validated file path + + Raises: + ValueError: If path doesn't exist or is not a file + """ + if not isinstance(value, str): + raise TypeError(f"Expected str, got {type(value).__name__}") + + value = value.strip() + + if not os.path.exists(value): + raise ValueError(f"Path does not exist: {value}") + + if not os.path.isfile(value): + raise ValueError(f"Path is not a file: {value}") + + return str.__new__(cls, value) + + def __repr__(self) -> str: + return f"FilePath({str.__repr__(self)})" + + +class _DirectoryPath(str): + """Directory path validation - must exist and be a directory. + + Validates that the path exists and points to a directory (not file). + """ + + __slots__ = () + + def __new__(cls, value: str) -> "_DirectoryPath": + """Create and validate directory path. + + Args: + value: Path to directory + + Returns: + Validated directory path + + Raises: + ValueError: If path doesn't exist or is not a directory + """ + if not isinstance(value, str): + raise TypeError(f"Expected str, got {type(value).__name__}") + + value = value.strip() + + if not os.path.exists(value): + raise ValueError(f"Path does not exist: {value}") + + if not os.path.isdir(value): + raise ValueError(f"Path is not a directory: {value}") + + return str.__new__(cls, value) + + def __repr__(self) -> str: + return f"DirectoryPath({str.__repr__(self)})" + + +# Export as type aliases for better DX +EmailStr = _EmailStr +HttpUrl = _HttpUrl +AnyUrl = _AnyUrl +SecretStr = _SecretStr +PostgresDsn = _PostgresDsn +RedisDsn = _RedisDsn +PaymentCardNumber = _PaymentCardNumber +FilePath = _FilePath +DirectoryPath = _DirectoryPath diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 0000000..baa4a0f --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,681 @@ +"""Tests for custom types and validators in msgspec_ext.types.""" + +import pytest + +from msgspec_ext.types import ( + AnyUrl, + DirectoryPath, + EmailStr, + FilePath, + HttpUrl, + NegativeFloat, + NegativeInt, + NonNegativeFloat, + NonNegativeInt, + NonPositiveFloat, + NonPositiveInt, + PaymentCardNumber, + PositiveFloat, + PositiveInt, + PostgresDsn, + RedisDsn, + SecretStr, +) + +# ============================================================================== +# Numeric Type Tests +# ============================================================================== + + +class TestPositiveInt: + """Tests for PositiveInt type.""" + + def test_valid_positive_integers(self): + """Should accept positive integers.""" + import msgspec + + assert msgspec.json.decode(b"1", type=PositiveInt) == 1 + assert msgspec.json.decode(b"100", type=PositiveInt) == 100 + assert msgspec.json.decode(b"999999", type=PositiveInt) == 999999 + + def test_reject_zero(self): + """Should reject zero (not strictly positive).""" + import msgspec + + with pytest.raises(msgspec.ValidationError, match="Expected `int` >= 1"): + msgspec.json.decode(b"0", type=PositiveInt) + + def test_reject_negative(self): + """Should reject negative integers.""" + import msgspec + + with pytest.raises(msgspec.ValidationError): + msgspec.json.decode(b"-1", type=PositiveInt) + with pytest.raises(msgspec.ValidationError): + msgspec.json.decode(b"-100", type=PositiveInt) + + +class TestNegativeInt: + """Tests for NegativeInt type.""" + + def test_valid_negative_integers(self): + """Should accept negative integers.""" + import msgspec + + assert msgspec.json.decode(b"-1", type=NegativeInt) == -1 + assert msgspec.json.decode(b"-100", type=NegativeInt) == -100 + + def test_reject_zero(self): + """Should reject zero (not strictly negative).""" + import msgspec + + with pytest.raises(msgspec.ValidationError): + msgspec.json.decode(b"0", type=NegativeInt) + + def test_reject_positive(self): + """Should reject positive integers.""" + import msgspec + + with pytest.raises(msgspec.ValidationError): + msgspec.json.decode(b"1", type=NegativeInt) + + +class TestNonNegativeInt: + """Tests for NonNegativeInt type.""" + + def test_valid_non_negative_integers(self): + """Should accept zero and positive integers.""" + import msgspec + + assert msgspec.json.decode(b"0", type=NonNegativeInt) == 0 + assert msgspec.json.decode(b"1", type=NonNegativeInt) == 1 + assert msgspec.json.decode(b"100", type=NonNegativeInt) == 100 + + def test_reject_negative(self): + """Should reject negative integers.""" + import msgspec + + with pytest.raises(msgspec.ValidationError): + msgspec.json.decode(b"-1", type=NonNegativeInt) + + +class TestNonPositiveInt: + """Tests for NonPositiveInt type.""" + + def test_valid_non_positive_integers(self): + """Should accept zero and negative integers.""" + import msgspec + + assert msgspec.json.decode(b"0", type=NonPositiveInt) == 0 + assert msgspec.json.decode(b"-1", type=NonPositiveInt) == -1 + assert msgspec.json.decode(b"-100", type=NonPositiveInt) == -100 + + def test_reject_positive(self): + """Should reject positive integers.""" + import msgspec + + with pytest.raises(msgspec.ValidationError): + msgspec.json.decode(b"1", type=NonPositiveInt) + + +class TestPositiveFloat: + """Tests for PositiveFloat type.""" + + def test_valid_positive_floats(self): + """Should accept positive floats.""" + import msgspec + + assert msgspec.json.decode(b"0.1", type=PositiveFloat) == 0.1 + assert msgspec.json.decode(b"1.0", type=PositiveFloat) == 1.0 + assert msgspec.json.decode(b"99.99", type=PositiveFloat) == 99.99 + + def test_reject_zero(self): + """Should reject zero (not strictly positive).""" + import msgspec + + with pytest.raises(msgspec.ValidationError): + msgspec.json.decode(b"0.0", type=PositiveFloat) + + def test_reject_negative(self): + """Should reject negative floats.""" + import msgspec + + with pytest.raises(msgspec.ValidationError): + msgspec.json.decode(b"-0.1", type=PositiveFloat) + + +class TestNegativeFloat: + """Tests for NegativeFloat type.""" + + def test_valid_negative_floats(self): + """Should accept negative floats.""" + import msgspec + + assert msgspec.json.decode(b"-0.1", type=NegativeFloat) == -0.1 + assert msgspec.json.decode(b"-99.99", type=NegativeFloat) == -99.99 + + def test_reject_zero(self): + """Should reject zero (not strictly negative).""" + import msgspec + + with pytest.raises(msgspec.ValidationError): + msgspec.json.decode(b"0.0", type=NegativeFloat) + + def test_reject_positive(self): + """Should reject positive floats.""" + import msgspec + + with pytest.raises(msgspec.ValidationError): + msgspec.json.decode(b"0.1", type=NegativeFloat) + + +class TestNonNegativeFloat: + """Tests for NonNegativeFloat type.""" + + def test_valid_non_negative_floats(self): + """Should accept zero and positive floats.""" + import msgspec + + assert msgspec.json.decode(b"0.0", type=NonNegativeFloat) == 0.0 + assert msgspec.json.decode(b"0.1", type=NonNegativeFloat) == 0.1 + assert msgspec.json.decode(b"99.99", type=NonNegativeFloat) == 99.99 + + def test_reject_negative(self): + """Should reject negative floats.""" + import msgspec + + with pytest.raises(msgspec.ValidationError): + msgspec.json.decode(b"-0.1", type=NonNegativeFloat) + + +class TestNonPositiveFloat: + """Tests for NonPositiveFloat type.""" + + def test_valid_non_positive_floats(self): + """Should accept zero and negative floats.""" + import msgspec + + assert msgspec.json.decode(b"0.0", type=NonPositiveFloat) == 0.0 + assert msgspec.json.decode(b"-0.1", type=NonPositiveFloat) == -0.1 + assert msgspec.json.decode(b"-99.99", type=NonPositiveFloat) == -99.99 + + def test_reject_positive(self): + """Should reject positive floats.""" + import msgspec + + with pytest.raises(msgspec.ValidationError): + msgspec.json.decode(b"0.1", type=NonPositiveFloat) + + +# ============================================================================== +# EmailStr Tests +# ============================================================================== + + +class TestEmailStr: + """Tests for EmailStr type.""" + + def test_valid_emails(self): + """Should accept valid email formats.""" + valid_emails = [ + "user@example.com", + "test.user@example.com", + "user+tag@example.co.uk", + "user123@test-domain.com", + "a@b.co", + ] + for email in valid_emails: + result = EmailStr(email) + assert str(result) == email.strip() + + def test_email_strips_whitespace(self): + """Should strip leading/trailing whitespace.""" + assert str(EmailStr(" user@example.com ")) == "user@example.com" + + def test_reject_invalid_emails(self): + """Should reject invalid email formats.""" + invalid_emails = [ + "", + "not-an-email", + "@example.com", + "user@", + "user @example.com", + "user@.com", + "user@domain", + "a" * 321, # Too long (> 320 chars) + ] + for email in invalid_emails: + with pytest.raises(ValueError): + EmailStr(email) + + def test_email_too_short(self): + """Should reject emails shorter than 3 characters.""" + with pytest.raises(ValueError, match="at least 3 characters"): + EmailStr("a@") + + def test_email_too_long(self): + """Should reject emails longer than 320 characters.""" + long_email = "a" * 310 + "@example.com" # 310 + 12 = 322 chars + with pytest.raises(ValueError, match="at most 320 characters"): + EmailStr(long_email) + + def test_email_type_error(self): + """Should reject non-string inputs.""" + with pytest.raises(TypeError): + EmailStr(123) # type: ignore + + +# ============================================================================== +# HttpUrl Tests +# ============================================================================== + + +class TestHttpUrl: + """Tests for HttpUrl type.""" + + def test_valid_http_urls(self): + """Should accept valid HTTP URLs.""" + valid_urls = [ + "http://example.com", + "https://example.com", + "http://example.com/path", + "https://example.com/path?query=value", + "http://subdomain.example.com:8080/path", + ] + for url in valid_urls: + result = HttpUrl(url) + assert str(result) == url.strip() + + def test_url_strips_whitespace(self): + """Should strip leading/trailing whitespace.""" + assert str(HttpUrl(" https://example.com ")) == "https://example.com" + + def test_reject_non_http_schemes(self): + """Should reject non-HTTP/HTTPS schemes.""" + invalid_urls = [ + "ftp://example.com", + "ws://example.com", + "file:///path/to/file", + ] + for url in invalid_urls: + with pytest.raises(ValueError): + HttpUrl(url) + + def test_reject_invalid_urls(self): + """Should reject invalid URL formats.""" + invalid_urls = [ + "", + "not a url", + "http://", + "://example.com", + ] + for url in invalid_urls: + with pytest.raises(ValueError): + HttpUrl(url) + + def test_url_too_long(self): + """Should reject URLs longer than 2083 characters.""" + long_url = "http://example.com/" + "a" * 2100 + with pytest.raises(ValueError, match="at most 2083 characters"): + HttpUrl(long_url) + + def test_url_type_error(self): + """Should reject non-string inputs.""" + with pytest.raises(TypeError): + HttpUrl(123) # type: ignore + + +# ============================================================================== +# AnyUrl Tests +# ============================================================================== + + +class TestAnyUrl: + """Tests for AnyUrl type.""" + + def test_valid_any_urls(self): + """Should accept URLs with any valid scheme.""" + valid_urls = [ + "http://example.com", + "https://example.com", + "ftp://ftp.example.com", + "ws://websocket.example.com", + "wss://secure.websocket.com", + "file:///path/to/file", + "mailto:user@example.com", + ] + for url in valid_urls: + result = AnyUrl(url) + assert str(result) == url.strip() + + def test_url_strips_whitespace(self): + """Should strip leading/trailing whitespace.""" + assert str(AnyUrl(" ftp://example.com ")) == "ftp://example.com" + + def test_reject_invalid_urls(self): + """Should reject invalid URL formats.""" + invalid_urls = [ + "", + "not a url", + "://example.com", + "http//example.com", # Missing colon + ] + for url in invalid_urls: + with pytest.raises(ValueError): + AnyUrl(url) + + def test_url_type_error(self): + """Should reject non-string inputs.""" + with pytest.raises(TypeError): + AnyUrl(123) # type: ignore + + +# ============================================================================== +# SecretStr Tests +# ============================================================================== + + +class TestSecretStr: + """Tests for SecretStr type.""" + + def test_create_secret(self): + """Should create secret string from regular string.""" + secret = SecretStr("my-secret-password") + assert isinstance(secret, str) + assert secret.get_secret_value() == "my-secret-password" + + def test_repr_masking(self): + """Should mask value in repr.""" + secret = SecretStr("my-secret-password") + assert repr(secret) == "SecretStr('**********')" + assert "my-secret-password" not in repr(secret) + + def test_str_masking(self): + """Should mask value in str().""" + secret = SecretStr("my-secret-password") + assert str(secret) == "**********" + assert "my-secret-password" not in str(secret) + + def test_get_secret_value(self): + """Should allow accessing actual secret value.""" + secret = SecretStr("my-secret-password") + assert secret.get_secret_value() == "my-secret-password" + + def test_empty_secret(self): + """Should allow empty secrets.""" + secret = SecretStr("") + assert secret.get_secret_value() == "" + assert str(secret) == "**********" + + def test_secret_type_error(self): + """Should reject non-string inputs.""" + with pytest.raises(TypeError): + SecretStr(123) # type: ignore + + def test_secret_in_print(self): + """Should be masked when printed.""" + secret = SecretStr("super-secret") + output = f"Password: {secret}" + assert output == "Password: **********" + assert "super-secret" not in output + + def test_secret_in_dict(self): + """Should be masked in dict repr.""" + config = {"password": SecretStr("secret123")} + dict_str = str(config) + assert "**********" in dict_str + assert "secret123" not in dict_str + + +# ============================================================================== +# PostgresDsn Tests +# ============================================================================== + + +class TestPostgresDsn: + """Tests for PostgresDsn type.""" + + def test_valid_postgres_dsn(self): + """Should accept valid PostgreSQL DSN.""" + valid_dsns = [ + "postgresql://user:pass@localhost:5432/dbname", + "postgres://user:pass@localhost:5432/dbname", + "postgresql://user@localhost/dbname", + "postgres://localhost/db", + ] + for dsn in valid_dsns: + result = PostgresDsn(dsn) + assert str(result) == dsn.strip() + + def test_dsn_strips_whitespace(self): + """Should strip leading/trailing whitespace.""" + dsn = " postgresql://user:pass@localhost/db " + result = PostgresDsn(dsn) + assert str(result) == "postgresql://user:pass@localhost/db" + + def test_reject_invalid_scheme(self): + """Should reject non-PostgreSQL schemes.""" + invalid_dsns = [ + "mysql://user:pass@localhost/db", + "redis://localhost:6379", + "http://localhost", + ] + for dsn in invalid_dsns: + with pytest.raises(ValueError, match="must start with"): + PostgresDsn(dsn) + + def test_reject_invalid_format(self): + """Should reject invalid DSN format.""" + invalid_dsns = [ + "postgresql://", # Empty + "postgresql://localhost", # Missing database (no slash) + ] + for dsn in invalid_dsns: + with pytest.raises(ValueError, match="Invalid PostgreSQL DSN"): + PostgresDsn(dsn) + + def test_dsn_type_error(self): + """Should reject non-string inputs.""" + with pytest.raises(TypeError): + PostgresDsn(123) # type: ignore + + +# ============================================================================== +# RedisDsn Tests +# ============================================================================== + + +class TestRedisDsn: + """Tests for RedisDsn type.""" + + def test_valid_redis_dsn(self): + """Should accept valid Redis DSN.""" + valid_dsns = [ + "redis://localhost:6379", + "redis://localhost:6379/0", + "redis://user:pass@localhost:6379", + "rediss://localhost:6380", # SSL + ] + for dsn in valid_dsns: + result = RedisDsn(dsn) + assert str(result) == dsn.strip() + + def test_dsn_strips_whitespace(self): + """Should strip leading/trailing whitespace.""" + dsn = " redis://localhost:6379 " + result = RedisDsn(dsn) + assert str(result) == "redis://localhost:6379" + + def test_reject_invalid_scheme(self): + """Should reject non-Redis schemes.""" + invalid_dsns = [ + "postgresql://localhost/db", + "http://localhost", + "mysql://localhost", + ] + for dsn in invalid_dsns: + with pytest.raises(ValueError, match="must start with"): + RedisDsn(dsn) + + def test_dsn_type_error(self): + """Should reject non-string inputs.""" + with pytest.raises(TypeError): + RedisDsn(123) # type: ignore + + +# ============================================================================== +# PaymentCardNumber Tests +# ============================================================================== + + +class TestPaymentCardNumber: + """Tests for PaymentCardNumber type.""" + + def test_valid_card_numbers(self): + """Should accept valid card numbers.""" + # Valid test card numbers (Luhn-valid) + valid_cards = [ + "4532015112830366", # Visa + "5425233430109903", # Mastercard + "374245455400126", # Amex + "6011000991300009", # Discover + ] + for card in valid_cards: + result = PaymentCardNumber(card) + assert len(result) >= 13 + + def test_card_with_spaces(self): + """Should accept card numbers with spaces.""" + card = "4532 0151 1283 0366" + result = PaymentCardNumber(card) + assert " " not in result # Spaces removed + assert len(result) == 16 + + def test_card_with_dashes(self): + """Should accept card numbers with dashes.""" + card = "4532-0151-1283-0366" + result = PaymentCardNumber(card) + assert "-" not in result # Dashes removed + assert len(result) == 16 + + def test_reject_invalid_luhn(self): + """Should reject card numbers that fail Luhn check.""" + invalid_cards = [ + "4532015112830367", # Last digit wrong + "1234567890123456", # Invalid + ] + for card in invalid_cards: + with pytest.raises(ValueError, match="failed Luhn check"): + PaymentCardNumber(card) + + def test_reject_non_digits(self): + """Should reject card numbers with non-digit characters.""" + invalid_cards = [ + "4532-0151-ABCD-0366", + "card number 123", + ] + for card in invalid_cards: + with pytest.raises(ValueError, match="only digits"): + PaymentCardNumber(card) + + def test_reject_wrong_length(self): + """Should reject card numbers with wrong length.""" + invalid_cards = [ + "123", # Too short + "12345678901234567890", # Too long + ] + for card in invalid_cards: + with pytest.raises(ValueError, match="must be"): + PaymentCardNumber(card) + + def test_card_repr_masking(self): + """Should mask card number in repr except last 4 digits.""" + card = PaymentCardNumber("4532015112830366") + card_repr = repr(card) + assert "0366" in card_repr # Last 4 visible + assert "4532" not in card_repr # First 4 masked + assert "*" in card_repr + + def test_card_type_error(self): + """Should reject non-string inputs.""" + with pytest.raises(TypeError): + PaymentCardNumber(123) # type: ignore + + +# ============================================================================== +# FilePath Tests +# ============================================================================== + + +class TestFilePath: + """Tests for FilePath type.""" + + def test_valid_file_path(self, tmp_path): + """Should accept existing file paths.""" + # Create a temporary file + test_file = tmp_path / "test.txt" + test_file.write_text("test content") + + result = FilePath(str(test_file)) + assert str(result) == str(test_file) + + def test_file_strips_whitespace(self, tmp_path): + """Should strip leading/trailing whitespace.""" + test_file = tmp_path / "test.txt" + test_file.write_text("test") + + result = FilePath(f" {test_file} ") + assert str(result) == str(test_file) + + def test_reject_nonexistent_file(self): + """Should reject paths that don't exist.""" + with pytest.raises(ValueError, match="does not exist"): + FilePath("/nonexistent/file.txt") + + def test_reject_directory(self, tmp_path): + """Should reject directory paths.""" + # tmp_path is a directory + with pytest.raises(ValueError, match="not a file"): + FilePath(str(tmp_path)) + + def test_file_type_error(self): + """Should reject non-string inputs.""" + with pytest.raises(TypeError): + FilePath(123) # type: ignore + + +# ============================================================================== +# DirectoryPath Tests +# ============================================================================== + + +class TestDirectoryPath: + """Tests for DirectoryPath type.""" + + def test_valid_directory_path(self, tmp_path): + """Should accept existing directory paths.""" + result = DirectoryPath(str(tmp_path)) + assert str(result) == str(tmp_path) + + def test_directory_strips_whitespace(self, tmp_path): + """Should strip leading/trailing whitespace.""" + result = DirectoryPath(f" {tmp_path} ") + assert str(result) == str(tmp_path) + + def test_reject_nonexistent_directory(self): + """Should reject paths that don't exist.""" + with pytest.raises(ValueError, match="does not exist"): + DirectoryPath("/nonexistent/directory") + + def test_reject_file(self, tmp_path): + """Should reject file paths.""" + test_file = tmp_path / "test.txt" + test_file.write_text("test") + + with pytest.raises(ValueError, match="not a directory"): + DirectoryPath(str(test_file)) + + def test_directory_type_error(self): + """Should reject non-string inputs.""" + with pytest.raises(TypeError): + DirectoryPath(123) # type: ignore