diff --git a/src/mcp/server/fastmcp/secure/__init__.py b/src/mcp/server/fastmcp/secure/__init__.py new file mode 100644 index 000000000..a20c932b8 --- /dev/null +++ b/src/mcp/server/fastmcp/secure/__init__.py @@ -0,0 +1,54 @@ +""" +Secure annotations and decorators for MCP tools, resources, and prompts. + +This module provides enhanced security features including: +- Bidirectional authentication (client ↔ tool) +- End-to-end encryption +- Tool attestation and signing +- Rate limiting and audit logging +""" + +from .annotations import ( + AuthMethod, + SecureAnnotations, + SecureToolAnnotations, + SecureResourceAnnotations, + SecurePromptAnnotations, +) +from .tool import SecureTool, secure_tool +from .resource import SecureResource, secure_resource +from .prompt import SecurePrompt, secure_prompt +from .identity import ToolIdentity, ClientIdentity, create_tool_identity +from .session import SecureSession, SessionManager +from .utils import SecureAnnotationProcessor, encrypt_data, decrypt_data + +__all__ = [ + # Annotations + "AuthMethod", + "SecureAnnotations", + "SecureToolAnnotations", + "SecureResourceAnnotations", + "SecurePromptAnnotations", + + # Secure wrappers + "SecureTool", + "SecureResource", + "SecurePrompt", + + # Decorators + "secure_tool", + "secure_resource", + "secure_prompt", + + # Identity & Session + "ToolIdentity", + "ClientIdentity", + "SecureSession", + "SessionManager", + "create_tool_identity", + + # Utils + "SecureAnnotationProcessor", + "encrypt_data", + "decrypt_data", +] \ No newline at end of file diff --git a/src/mcp/server/fastmcp/secure/annotations.py b/src/mcp/server/fastmcp/secure/annotations.py new file mode 100644 index 000000000..4c93eb36e --- /dev/null +++ b/src/mcp/server/fastmcp/secure/annotations.py @@ -0,0 +1,326 @@ +""" +Secure annotations for MCP tools, resources, and prompts. + +These annotations extend the standard MCP annotations with security features. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Optional + +from mcp.types import ToolAnnotations, ResourceAnnotations, PromptAnnotations + + +class AuthMethod(Enum): + """Supported authentication methods.""" + JWT = "jwt" + CERTIFICATE = "certificate" + TEE_ATTESTATION = "tee" + OAUTH = "oauth" + API_KEY = "api_key" + MTLS = "mtls" # Mutual TLS + + +@dataclass +class SecureAnnotations: + """ + Base security annotations that can be attached to tools, resources, or prompts. + + These annotations enable security features like authentication, encryption, + and attestation for MCP operations. + """ + + # Authentication settings + require_auth: bool = False + auth_methods: list[AuthMethod] = field(default_factory=lambda: [AuthMethod.JWT]) + required_permissions: set[str] = field(default_factory=set) + require_mutual_auth: bool = False # Bidirectional authentication + + # Encryption settings + encrypt_input: bool = False + encrypt_output: bool = False + encryption_algorithm: str = "AES-256-GCM" + key_exchange_method: str = "ECDH" # ECDH, RSA, Pre-shared + + # Tool/Server attestation + require_tool_attestation: bool = False + tool_certificate_fingerprint: Optional[str] = None + attestation_type: Optional[str] = None # "software", "sgx", "sev", "trustzone" + tool_signature_required: bool = False + + # Client verification + verify_client_certificate: bool = False + trusted_client_issuers: list[str] = field(default_factory=list) + client_attestation_required: bool = False + + # Audit and compliance + audit_log: bool = True + audit_include_inputs: bool = False + audit_include_outputs: bool = False + audit_retention_days: int = 90 + + # Rate limiting + rate_limit: Optional[int] = None # requests per minute + rate_limit_per_client: bool = True + burst_limit: Optional[int] = None + + # Data handling + security_level: str = "standard" # "standard", "high", "critical" + data_classification: str = "public" # "public", "internal", "confidential", "secret" + compliance_tags: list[str] = field(default_factory=list) # ["HIPAA", "PCI-DSS", "GDPR", "SOC2"] + + # Session management + session_timeout_minutes: int = 60 + require_session_binding: bool = False # Bind session to client IP/fingerprint + max_concurrent_sessions: Optional[int] = None + + # Advanced security + require_replay_protection: bool = False + max_request_age_seconds: int = 300 # For replay protection + require_integrity_check: bool = True # Verify message integrity + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "require_auth": self.require_auth, + "auth_methods": [method.value for method in self.auth_methods], + "required_permissions": list(self.required_permissions), + "require_mutual_auth": self.require_mutual_auth, + "encrypt_input": self.encrypt_input, + "encrypt_output": self.encrypt_output, + "encryption_algorithm": self.encryption_algorithm, + "key_exchange_method": self.key_exchange_method, + "require_tool_attestation": self.require_tool_attestation, + "tool_certificate_fingerprint": self.tool_certificate_fingerprint, + "attestation_type": self.attestation_type, + "tool_signature_required": self.tool_signature_required, + "verify_client_certificate": self.verify_client_certificate, + "trusted_client_issuers": self.trusted_client_issuers, + "client_attestation_required": self.client_attestation_required, + "audit_log": self.audit_log, + "audit_include_inputs": self.audit_include_inputs, + "audit_include_outputs": self.audit_include_outputs, + "security_level": self.security_level, + "data_classification": self.data_classification, + "compliance_tags": self.compliance_tags, + "session_timeout_minutes": self.session_timeout_minutes, + "require_session_binding": self.require_session_binding, + "require_replay_protection": self.require_replay_protection, + "require_integrity_check": self.require_integrity_check, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> SecureAnnotations: + """Create from dictionary.""" + auth_methods = [AuthMethod(m) for m in data.get("auth_methods", ["jwt"])] + return cls( + require_auth=data.get("require_auth", False), + auth_methods=auth_methods, + required_permissions=set(data.get("required_permissions", [])), + require_mutual_auth=data.get("require_mutual_auth", False), + encrypt_input=data.get("encrypt_input", False), + encrypt_output=data.get("encrypt_output", False), + encryption_algorithm=data.get("encryption_algorithm", "AES-256-GCM"), + key_exchange_method=data.get("key_exchange_method", "ECDH"), + require_tool_attestation=data.get("require_tool_attestation", False), + tool_certificate_fingerprint=data.get("tool_certificate_fingerprint"), + attestation_type=data.get("attestation_type"), + tool_signature_required=data.get("tool_signature_required", False), + verify_client_certificate=data.get("verify_client_certificate", False), + trusted_client_issuers=data.get("trusted_client_issuers", []), + client_attestation_required=data.get("client_attestation_required", False), + audit_log=data.get("audit_log", True), + security_level=data.get("security_level", "standard"), + data_classification=data.get("data_classification", "public"), + compliance_tags=data.get("compliance_tags", []), + ) + + +class SecureToolAnnotations(ToolAnnotations): + """ + Tool annotations with integrated security features. + + This extends the standard ToolAnnotations with security metadata. + """ + + def __init__( + self, + secure: Optional[SecureAnnotations] = None, + **kwargs + ): + super().__init__(**kwargs) + self.secure = secure or SecureAnnotations() + + # Store security annotations in extensions + if not hasattr(self, 'extensions'): + self.extensions = {} + self.extensions["security"] = self.secure.to_dict() + + @classmethod + def create( + cls, + # Security parameters + require_auth: bool = False, + auth_methods: Optional[list[AuthMethod]] = None, + required_permissions: Optional[set[str]] = None, + encrypt_io: bool = False, + require_mutual_auth: bool = False, + security_level: str = "standard", + + # Standard tool annotation parameters + audience: Optional[list[str]] = None, + capabilities: Optional[dict[str, Any]] = None, + **kwargs + ) -> SecureToolAnnotations: + """ + Factory method to create secure tool annotations. + + Args: + require_auth: Whether to require authentication + auth_methods: List of accepted authentication methods + required_permissions: Set of required permissions + encrypt_io: Whether to encrypt input/output + require_mutual_auth: Whether to require bidirectional authentication + security_level: Security level (standard/high/critical) + audience: Target audience for the tool + capabilities: Tool capabilities + **kwargs: Additional security parameters + """ + secure_annotations = SecureAnnotations( + require_auth=require_auth, + auth_methods=auth_methods or [AuthMethod.JWT], + required_permissions=required_permissions or set(), + encrypt_input=encrypt_io, + encrypt_output=encrypt_io, + require_mutual_auth=require_mutual_auth, + security_level=security_level, + **kwargs + ) + + return cls( + secure=secure_annotations, + audience=audience, + capabilities=capabilities + ) + + +class SecureResourceAnnotations(ResourceAnnotations): + """ + Resource annotations with integrated security features. + + This extends the standard ResourceAnnotations with security metadata. + """ + + def __init__( + self, + secure: Optional[SecureAnnotations] = None, + **kwargs + ): + super().__init__(**kwargs) + self.secure = secure or SecureAnnotations() + + # Store security annotations in extensions + if not hasattr(self, 'extensions'): + self.extensions = {} + self.extensions["security"] = self.secure.to_dict() + + @classmethod + def create( + cls, + # Security parameters + require_auth: bool = False, + data_classification: str = "public", + encrypt_io: bool = False, + audit_access: bool = True, + + # Standard resource annotation parameters + content_type: Optional[str] = None, + cache_control: Optional[str] = None, + **kwargs + ) -> SecureResourceAnnotations: + """ + Factory method to create secure resource annotations. + + Args: + require_auth: Whether to require authentication + data_classification: Data classification level + encrypt_io: Whether to encrypt input/output + audit_access: Whether to audit resource access + content_type: Resource content type + cache_control: Cache control headers + **kwargs: Additional security parameters + """ + secure_annotations = SecureAnnotations( + require_auth=require_auth, + data_classification=data_classification, + encrypt_input=encrypt_io, + encrypt_output=encrypt_io, + audit_log=audit_access, + **kwargs + ) + + return cls( + secure=secure_annotations, + content_type=content_type, + cache_control=cache_control + ) + + +class SecurePromptAnnotations(PromptAnnotations): + """ + Prompt annotations with integrated security features. + + This extends the standard PromptAnnotations with security metadata. + """ + + def __init__( + self, + secure: Optional[SecureAnnotations] = None, + **kwargs + ): + super().__init__(**kwargs) + self.secure = secure or SecureAnnotations() + + # Store security annotations in extensions + if not hasattr(self, 'extensions'): + self.extensions = {} + self.extensions["security"] = self.secure.to_dict() + + @classmethod + def create( + cls, + # Security parameters + require_auth: bool = False, + audit_usage: bool = True, + compliance_tags: Optional[list[str]] = None, + + # Standard prompt annotation parameters + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + **kwargs + ) -> SecurePromptAnnotations: + """ + Factory method to create secure prompt annotations. + + Args: + require_auth: Whether to require authentication + audit_usage: Whether to audit prompt usage + compliance_tags: Compliance tags (e.g., ["GDPR", "HIPAA"]) + max_tokens: Maximum tokens for prompt + temperature: Temperature for prompt generation + **kwargs: Additional security parameters + """ + secure_annotations = SecureAnnotations( + require_auth=require_auth, + audit_log=audit_usage, + compliance_tags=compliance_tags or [], + **kwargs + ) + + return cls( + secure=secure_annotations, + max_tokens=max_tokens, + temperature=temperature + ) diff --git a/src/mcp/server/fastmcp/secure/identity.py b/src/mcp/server/fastmcp/secure/identity.py new file mode 100644 index 000000000..3c29b838e --- /dev/null +++ b/src/mcp/server/fastmcp/secure/identity.py @@ -0,0 +1,390 @@ +""" +Identity management for secure MCP operations. + +Handles both tool identity (server-side) and client identity verification. +""" + +from __future__ import annotations + +import hashlib +import json +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Any, Optional + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec, rsa, padding +from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes, PublicKeyTypes +from cryptography.x509.oid import NameOID + +from .annotations import AuthMethod + + +@dataclass +class ToolIdentity: + """ + Represents the cryptographic identity of a tool/server. + + This is used for: + - Tool attestation (proving the tool is legitimate) + - Response signing (ensuring response integrity) + - Mutual authentication (bidirectional auth with client) + """ + + tool_id: str + name: str + version: str + certificate: x509.Certificate + private_key: PrivateKeyTypes + trusted_issuers: list[x509.Certificate] + + # Optional attestation for secure enclaves + attestation_report: Optional[dict] = None + attestation_type: Optional[str] = None # "software", "sgx", "sev", "trustzone" + + # Tool capabilities and metadata + capabilities: set[str] = field(default_factory=set) + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def fingerprint(self) -> str: + """Get SHA256 fingerprint of the tool's certificate.""" + cert_der = self.certificate.public_bytes(serialization.Encoding.DER) + return hashlib.sha256(cert_der).hexdigest() + + @property + def public_key(self) -> PublicKeyTypes: + """Get the public key from the certificate.""" + return self.certificate.public_key() + + def sign_data(self, data: bytes) -> bytes: + """ + Sign data with the tool's private key. + + Args: + data: Data to sign + + Returns: + Digital signature + """ + if isinstance(self.private_key, rsa.RSAPrivateKey): + return self.private_key.sign( + data, + padding.PSS( + mgf=padding.MGF1(hashes.SHA256()), + salt_length=padding.PSS.MAX_LENGTH + ), + hashes.SHA256() + ) + elif isinstance(self.private_key, ec.EllipticCurvePrivateKey): + return self.private_key.sign(data, ec.ECDSA(hashes.SHA256())) + else: + raise ValueError(f"Unsupported key type: {type(self.private_key)}") + + def verify_signature(self, data: bytes, signature: bytes, public_key: PublicKeyTypes) -> bool: + """ + Verify a signature using a public key. + + Args: + data: Original data + signature: Signature to verify + public_key: Public key to verify with + + Returns: + True if signature is valid + """ + try: + if isinstance(public_key, rsa.RSAPublicKey): + public_key.verify( + signature, + data, + padding.PSS( + mgf=padding.MGF1(hashes.SHA256()), + salt_length=padding.PSS.MAX_LENGTH + ), + hashes.SHA256() + ) + elif isinstance(public_key, ec.EllipticCurvePublicKey): + public_key.verify(signature, data, ec.ECDSA(hashes.SHA256())) + else: + return False + return True + except Exception: + return False + + def to_attestation(self) -> dict[str, Any]: + """ + Generate attestation data for the tool. + + Returns: + Dictionary containing tool attestation information + """ + import base64 + + attestation = { + "tool_id": self.tool_id, + "name": self.name, + "version": self.version, + "fingerprint": self.fingerprint, + "certificate": base64.b64encode( + self.certificate.public_bytes(serialization.Encoding.PEM) + ).decode(), + "capabilities": list(self.capabilities), + "timestamp": datetime.utcnow().isoformat(), + } + + # Add hardware attestation if available + if self.attestation_report: + attestation["attestation"] = { + "type": self.attestation_type, + "report": self.attestation_report, + } + + # Sign the attestation + attestation_bytes = json.dumps(attestation, sort_keys=True).encode() + signature = self.sign_data(attestation_bytes) + attestation["signature"] = base64.b64encode(signature).decode() + + return attestation + + def verify_client_signature(self, data: bytes, signature: bytes, client_cert: x509.Certificate) -> bool: + """ + Verify a signature from a client certificate. + + Args: + data: Data that was signed + signature: Client's signature + client_cert: Client's certificate + + Returns: + True if signature is valid + """ + client_public_key = client_cert.public_key() + return self.verify_signature(data, signature, client_public_key) + + +@dataclass +class ClientIdentity: + """ + Represents an authenticated client identity. + + This is created after successful authentication and contains + the client's permissions and metadata. + """ + + client_id: str + authentication_method: AuthMethod + credentials: Any # JWT token, certificate, attestation, etc. + permissions: set[str] + + # Optional fields + session_id: Optional[str] = None + organization: Optional[str] = None + email: Optional[str] = None + + # Metadata + metadata: dict[str, Any] = field(default_factory=dict) + authenticated_at: datetime = field(default_factory=datetime.utcnow) + expires_at: Optional[datetime] = None + + # Certificate-based auth specifics + certificate: Optional[x509.Certificate] = None + certificate_fingerprint: Optional[str] = None + + # Rate limiting and quotas + rate_limit: Optional[int] = None + quota_remaining: Optional[int] = None + + def has_permission(self, permission: str) -> bool: + """ + Check if client has a specific permission. + + Args: + permission: Permission to check (e.g., "tool.execute", "resource.read") + + Returns: + True if client has the permission + """ + # Check exact permission + if permission in self.permissions: + return True + + # Check wildcard permissions + if "*" in self.permissions: + return True + + # Check hierarchical permissions (e.g., "tool.*" matches "tool.execute") + parts = permission.split(".") + for i in range(len(parts)): + wildcard_perm = ".".join(parts[:i+1]) + ".*" + if wildcard_perm in self.permissions: + return True + + return False + + def is_expired(self) -> bool: + """Check if the client identity has expired.""" + if self.expires_at is None: + return False + return datetime.utcnow() > self.expires_at + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "client_id": self.client_id, + "authentication_method": self.authentication_method.value, + "permissions": list(self.permissions), + "organization": self.organization, + "email": self.email, + "authenticated_at": self.authenticated_at.isoformat(), + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "certificate_fingerprint": self.certificate_fingerprint, + "metadata": self.metadata, + } + + +def create_tool_identity( + tool_name: str, + tool_version: str, + organization: str = "MCP-Secure", + country: str = "US", + validity_days: int = 365, + key_type: str = "EC" # "EC" or "RSA" +) -> ToolIdentity: + """ + Create a tool identity with a self-signed certificate. + + In production, you would use a proper CA-signed certificate. + + Args: + tool_name: Name of the tool + tool_version: Version of the tool + organization: Organization name + country: Country code + validity_days: Certificate validity period + key_type: Key type ("EC" for elliptic curve, "RSA" for RSA) + + Returns: + ToolIdentity with generated certificate and key + """ + # Generate key pair + if key_type == "RSA": + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + else: # EC + private_key = ec.generate_private_key(ec.SECP256R1()) + + # Create certificate + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COUNTRY_NAME, country), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization), + x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "MCP-Tools"), + x509.NameAttribute(NameOID.COMMON_NAME, f"{tool_name}-v{tool_version}"), + ]) + + # Build certificate + builder = x509.CertificateBuilder() + builder = builder.subject_name(subject) + builder = builder.issuer_name(issuer) + builder = builder.public_key(private_key.public_key()) + builder = builder.serial_number(x509.random_serial_number()) + builder = builder.not_valid_before(datetime.utcnow()) + builder = builder.not_valid_after(datetime.utcnow() + timedelta(days=validity_days)) + + # Add extensions + builder = builder.add_extension( + x509.SubjectAlternativeName([ + x509.DNSName(f"{tool_name}.local"), + x509.DNSName("localhost"), + ]), + critical=False, + ) + + builder = builder.add_extension( + x509.KeyUsage( + digital_signature=True, + key_encipherment=True, + content_commitment=True, + data_encipherment=False, + key_agreement=True, + key_cert_sign=False, + crl_sign=False, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) + + builder = builder.add_extension( + x509.ExtendedKeyUsage([ + x509.oid.ExtendedKeyUsageOID.SERVER_AUTH, + x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH, + ]), + critical=True, + ) + + # Self-sign the certificate + certificate = builder.sign(private_key, hashes.SHA256()) + + return ToolIdentity( + tool_id=f"{tool_name}-{tool_version}", + name=tool_name, + version=tool_version, + certificate=certificate, + private_key=private_key, + trusted_issuers=[certificate], # Self-signed + capabilities={ + "authentication.mutual", + "encryption.aes256", + "signing.sha256", + }, + metadata={ + "created_at": datetime.utcnow().isoformat(), + "key_type": key_type, + "organization": organization, + } + ) + + +def verify_tool_certificate( + certificate: x509.Certificate, + trusted_cas: list[x509.Certificate], + check_revocation: bool = True +) -> tuple[bool, Optional[str]]: + """ + Verify a tool's certificate against trusted CAs. + + Args: + certificate: Certificate to verify + trusted_cas: List of trusted CA certificates + check_revocation: Whether to check certificate revocation + + Returns: + Tuple of (is_valid, error_message) + """ + # Check certificate validity period + now = datetime.utcnow() + if now < certificate.not_valid_before: + return False, "Certificate not yet valid" + if now > certificate.not_valid_after: + return False, "Certificate has expired" + + # Verify certificate chain + for ca in trusted_cas: + try: + ca.public_key().verify( + certificate.signature, + certificate.tbs_certificate_bytes, + certificate.signature_algorithm_oid._name + ) + + # If we reach here, signature is valid + if check_revocation: + # In production, check CRL or OCSP + pass + + return True, None + except Exception: + continue + + return False, "Certificate not signed by trusted CA" \ No newline at end of file diff --git a/src/mcp/server/fastmcp/secure/prompt.py b/src/mcp/server/fastmcp/secure/prompt.py new file mode 100644 index 000000000..8956f1108 --- /dev/null +++ b/src/mcp/server/fastmcp/secure/prompt.py @@ -0,0 +1,380 @@ +""" +Secure prompt implementation with authentication and compliance. +""" + +from __future__ import annotations + +import functools +import hashlib +import inspect +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, cast + +from mcp.server.fastmcp.prompts.base import Prompt +from mcp.types import Error, Message + +from .annotations import AuthMethod, SecureAnnotations, SecurePromptAnnotations +from .identity import ToolIdentity +from .utils import SecureAnnotationProcessor + +if TYPE_CHECKING: + from mcp.server.fastmcp.server import Context, FastMCP + +F = TypeVar('F', bound=Callable[..., Any]) + + +class SecurePrompt(Prompt): + """ + Secure prompt with authentication, compliance, and audit support. + + This extends the base Prompt class with security features for + handling sensitive prompts and ensuring compliance. + """ + + def __init__( + self, + secure_annotations: SecureAnnotations, + tool_identity: Optional[ToolIdentity] = None, + processor: Optional[SecureAnnotationProcessor] = None, + **kwargs + ): + super().__init__(**kwargs) + self.secure_annotations = secure_annotations + self.tool_identity = tool_identity + self.processor = processor or SecureAnnotationProcessor(tool_identity=tool_identity) + + # Add compliance metadata + self._compliance_metadata = { + "compliance_tags": secure_annotations.compliance_tags, + "data_classification": secure_annotations.data_classification, + "audit_required": secure_annotations.audit_log, + } + + async def render(self, arguments: dict[str, Any] | None = None) -> list[Message]: + """ + Render the secure prompt with authentication and compliance checks. + + This method: + 1. Verifies client authentication (if required) + 2. Checks compliance requirements + 3. Sanitizes/filters sensitive information + 4. Renders the prompt + 5. Audits the usage + """ + arguments = arguments or {} + + # In production, extract auth from request context + auth_header = None # Would come from request context + + # Process secure request + try: + session, processed_args = await self.processor.process_secure_request( + annotations=self.secure_annotations, + auth_header=auth_header, + client_cert=None, + request_data=arguments + ) + except Error as e: + # Audit failed access attempt + self.processor._audit_log( + session=None, + action="prompt_access_denied", + include_data=False, + data={"prompt": self.name, "error": str(e)} + ) + raise + + # Check compliance requirements + if self.secure_annotations.compliance_tags: + await self._check_compliance(session, processed_args) + + # Sanitize sensitive information if needed + if self.secure_annotations.data_classification in ["confidential", "secret"]: + processed_args = await self._sanitize_arguments(processed_args) + + # Render the actual prompt + messages = await super().render(processed_args) + + # Post-process messages for security + secure_messages = await self._secure_messages(messages, session) + + # Audit prompt usage + if self.secure_annotations.audit_log: + await self._audit_prompt_usage(session, processed_args, secure_messages) + + return secure_messages + + async def _check_compliance(self, session, arguments: dict) -> None: + """ + Check compliance requirements before rendering the prompt. + + Args: + session: Secure session + arguments: Prompt arguments + + Raises: + Error: If compliance requirements are not met + """ + for tag in self.secure_annotations.compliance_tags: + if tag == "GDPR": + # Check GDPR compliance (e.g., purpose limitation, data minimization) + if "personal_data" in arguments and not session.client_identity.has_permission("gdpr.process"): + raise Error(code=403, message="GDPR: Missing permission to process personal data") + + elif tag == "HIPAA": + # Check HIPAA compliance for health information + if "health_data" in arguments and not session.client_identity.has_permission("hipaa.access"): + raise Error(code=403, message="HIPAA: Not authorized to access health information") + + elif tag == "PCI-DSS": + # Check PCI-DSS compliance for payment card data + if "card_data" in arguments: + # Ensure card data is masked/tokenized + if not self._is_card_data_safe(arguments["card_data"]): + raise Error(code=400, message="PCI-DSS: Card data must be tokenized") + + async def _sanitize_arguments(self, arguments: dict) -> dict: + """ + Sanitize sensitive information from arguments. + + Args: + arguments: Original arguments + + Returns: + Sanitized arguments + """ + sanitized = {} + for key, value in arguments.items(): + if key in ["ssn", "credit_card", "password", "api_key"]: + # Mask sensitive fields + sanitized[key] = self._mask_sensitive_data(str(value)) + elif isinstance(value, dict): + # Recursively sanitize nested data + sanitized[key] = await self._sanitize_arguments(value) + else: + sanitized[key] = value + + return sanitized + + def _mask_sensitive_data(self, data: str) -> str: + """Mask sensitive data while preserving format hints.""" + if len(data) <= 4: + return "*" * len(data) + + # Show first and last 2 characters only + return data[:2] + "*" * (len(data) - 4) + data[-2:] + + def _is_card_data_safe(self, card_data: str) -> bool: + """Check if card data is properly tokenized/masked.""" + # Check if it's a token (e.g., tok_xxxx) or masked number + return card_data.startswith("tok_") or "*" in card_data + + async def _secure_messages(self, messages: list[Message], session) -> list[Message]: + """ + Apply security transformations to messages. + + Args: + messages: Original messages + session: Secure session + + Returns: + Secured messages + """ + secure_msgs = [] + + for msg in messages: + secure_msg = msg.copy() if hasattr(msg, 'copy') else msg + + # Add security headers to system messages + if isinstance(msg, dict) and msg.get("role") == "system": + if self.secure_annotations.compliance_tags: + compliance_notice = f"[Compliance: {', '.join(self.secure_annotations.compliance_tags)}] " + secure_msg["content"] = compliance_notice + secure_msg.get("content", "") + + # Add classification labels + if self.secure_annotations.data_classification != "public": + if isinstance(secure_msg, dict): + secure_msg["metadata"] = secure_msg.get("metadata", {}) + secure_msg["metadata"]["classification"] = self.secure_annotations.data_classification + + secure_msgs.append(secure_msg) + + return secure_msgs + + async def _audit_prompt_usage(self, session, arguments: dict, messages: list) -> None: + """ + Audit prompt usage for compliance and security monitoring. + + Args: + session: Secure session + arguments: Prompt arguments + messages: Generated messages + """ + audit_data = { + "prompt_name": self.name, + "client_id": session.client_identity.client_id if session.client_identity else "anonymous", + "compliance_tags": self.secure_annotations.compliance_tags, + "data_classification": self.secure_annotations.data_classification, + "message_count": len(messages), + } + + if self.secure_annotations.audit_include_inputs: + # Hash sensitive arguments for audit + audit_data["argument_hash"] = hashlib.sha256( + str(arguments).encode() + ).hexdigest() + + if self.secure_annotations.audit_include_outputs: + # Include message metadata (not content) + audit_data["message_roles"] = [ + msg.get("role") if isinstance(msg, dict) else "unknown" + for msg in messages + ] + + self.processor._audit_log( + session=session, + action="prompt_rendered", + include_data=True, + data=audit_data + ) + + +def secure_prompt( + # Security parameters + require_auth: bool = False, + audit_usage: bool = True, + compliance_tags: Optional[list[str]] = None, + data_classification: str = "public", + + # Standard prompt parameters + name: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, +) -> Callable[[F], F]: + """ + Decorator to create a secure prompt with compliance and audit support. + + This decorator wraps a function to create a secure MCP prompt that supports: + - Authentication and authorization + - Compliance checking (GDPR, HIPAA, PCI-DSS, etc.) + - Sensitive data sanitization + - Usage auditing + + Args: + require_auth: Whether to require authentication + audit_usage: Whether to audit prompt usage + compliance_tags: Compliance requirements (e.g., ["GDPR", "HIPAA"]) + data_classification: Data classification level + name: Prompt name + title: Prompt title + description: Prompt description + + Example: + ```python + @secure_prompt( + require_auth=True, + compliance_tags=["GDPR", "HIPAA"], + data_classification="confidential", + audit_usage=True + ) + async def medical_diagnosis_prompt( + patient_id: str, + symptoms: list[str], + ctx: Context + ) -> list[Message]: + # Ensure HIPAA compliance + return [ + { + "role": "system", + "content": "You are a medical AI assistant. Maintain patient confidentiality." + }, + { + "role": "user", + "content": f"Analyze symptoms for patient (ID: {patient_id}): {symptoms}" + } + ] + ``` + """ + # Create secure annotations + secure_annotations = SecureAnnotations( + require_auth=require_auth, + audit_log=audit_usage, + compliance_tags=compliance_tags or [], + data_classification=data_classification, + ) + + def decorator(func: F) -> F: + @functools.wraps(func) + async def wrapper(*args, **kwargs): + # This wrapper would be replaced by SecurePrompt.render in production + result = await func(*args, **kwargs) if inspect.iscoroutinefunction(func) else func(*args, **kwargs) + return result + + # Store security metadata on the function + wrapper._secure_annotations = secure_annotations + wrapper._is_secure_prompt = True + wrapper._compliance_tags = compliance_tags or [] + wrapper._data_classification = data_classification + + return cast(F, wrapper) + + return decorator + + +class ComplianceValidator: + """ + Validator for ensuring prompts meet compliance requirements. + """ + + @staticmethod + def validate_gdpr(prompt_content: str, metadata: dict) -> tuple[bool, Optional[str]]: + """ + Validate GDPR compliance for a prompt. + + Returns: + Tuple of (is_valid, error_message) + """ + # Check for purpose limitation + if "purpose" not in metadata: + return False, "GDPR requires explicit purpose declaration" + + # Check for data minimization + sensitive_keywords = ["ssn", "email", "phone", "address", "name"] + if any(keyword in prompt_content.lower() for keyword in sensitive_keywords): + if "legal_basis" not in metadata: + return False, "GDPR requires legal basis for processing personal data" + + return True, None + + @staticmethod + def validate_hipaa(prompt_content: str, metadata: dict) -> tuple[bool, Optional[str]]: + """ + Validate HIPAA compliance for a prompt. + + Returns: + Tuple of (is_valid, error_message) + """ + # Check for PHI safeguards + phi_keywords = ["patient", "diagnosis", "treatment", "medical", "health"] + if any(keyword in prompt_content.lower() for keyword in phi_keywords): + if "hipaa_safeguards" not in metadata: + return False, "HIPAA requires safeguards for Protected Health Information" + + return True, None + + @staticmethod + def validate_pci_dss(prompt_content: str, metadata: dict) -> tuple[bool, Optional[str]]: + """ + Validate PCI-DSS compliance for a prompt. + + Returns: + Tuple of (is_valid, error_message) + """ + # Check for credit card data + import re + + # Simple regex for credit card patterns (not comprehensive) + cc_pattern = r'\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b' + if re.search(cc_pattern, prompt_content): + return False, "PCI-DSS prohibits storage of unencrypted card numbers" + + return True, None \ No newline at end of file diff --git a/src/mcp/server/fastmcp/secure/resource.py b/src/mcp/server/fastmcp/secure/resource.py new file mode 100644 index 000000000..23e56b4ef --- /dev/null +++ b/src/mcp/server/fastmcp/secure/resource.py @@ -0,0 +1,270 @@ +""" +Secure resource implementation with authentication and encryption. +""" + +from __future__ import annotations + +import functools +import inspect +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, cast + +from mcp.server.fastmcp.resources.base import Resource +from mcp.types import Error + +from .annotations import AuthMethod, SecureAnnotations, SecureResourceAnnotations +from .identity import ToolIdentity +from .utils import SecureAnnotationProcessor + +if TYPE_CHECKING: + from mcp.server.fastmcp.server import FastMCP + +F = TypeVar('F', bound=Callable[..., Any]) + + +class SecureResource(Resource): + """ + Secure resource with authentication, encryption, and access control. + + This extends the base Resource class with security features. + """ + + def __init__( + self, + secure_annotations: SecureAnnotations, + tool_identity: Optional[ToolIdentity] = None, + processor: Optional[SecureAnnotationProcessor] = None, + **kwargs + ): + super().__init__(**kwargs) + self.secure_annotations = secure_annotations + self.tool_identity = tool_identity + self.processor = processor or SecureAnnotationProcessor(tool_identity=tool_identity) + + # Add security metadata to the resource + self._secure_metadata = { + "data_classification": secure_annotations.data_classification, + "encryption_required": secure_annotations.encrypt_output, + "auth_required": secure_annotations.require_auth, + "compliance_tags": secure_annotations.compliance_tags, + } + + async def read(self) -> str | bytes: + """ + Read the secure resource with authentication and encryption. + + This method: + 1. Verifies client authentication (if required) + 2. Checks access permissions + 3. Reads the resource + 4. Encrypts the content (if required) + 5. Audits the access + """ + # In production, extract auth from request context + auth_header = None # Would come from request context + + # Process secure request + try: + session, _ = await self.processor.process_secure_request( + annotations=self.secure_annotations, + auth_header=auth_header, + client_cert=None, + request_data={} + ) + except Error as e: + # Audit failed access attempt + self.processor._audit_log( + session=None, + action="resource_access_denied", + include_data=False, + data={"uri": self.uri, "error": str(e)} + ) + raise + + # Check specific resource permissions + if session.client_identity: + resource_permission = f"resource.read.{self.name or self.uri}" + if not session.client_identity.has_permission(resource_permission) and \ + not session.client_identity.has_permission("resource.read.*"): + raise Error( + code=403, + message=f"Client lacks permission to read resource: {self.uri}" + ) + + # Read the actual resource content + content = await super().read() + + # Process secure response (encrypt if required) + if self.secure_annotations.encrypt_output: + secure_content = await self.processor.process_secure_response( + annotations=self.secure_annotations, + session=session, + response_data=content + ) + + # Convert encrypted response to string/bytes + if isinstance(secure_content, dict) and "data" in secure_content: + content = secure_content["data"] + + # Audit successful access + if self.secure_annotations.audit_log: + self.processor._audit_log( + session=session, + action="resource_accessed", + include_data=self.secure_annotations.audit_include_outputs, + data={ + "uri": self.uri, + "classification": self.secure_annotations.data_classification, + "size": len(content) if isinstance(content, (str, bytes)) else None + } + ) + + return content + + +def secure_resource( + uri: str, + # Security parameters + require_auth: bool = False, + data_classification: str = "public", + encrypt_io: bool = False, + audit_access: bool = True, + compliance_tags: Optional[list[str]] = None, + + # Standard resource parameters + name: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + mime_type: Optional[str] = None, +) -> Callable[[F], F]: + """ + Decorator to create a secure resource with authentication and encryption. + + This decorator wraps a function to create a secure MCP resource that supports: + - Access control and authentication + - Data classification and compliance + - Encryption for sensitive data + - Audit logging + + Args: + uri: Resource URI + require_auth: Whether to require authentication + data_classification: Classification level (public/internal/confidential/secret) + encrypt_io: Whether to encrypt the resource content + audit_access: Whether to audit resource access + compliance_tags: Compliance tags (e.g., ["GDPR", "HIPAA"]) + name: Resource name + title: Resource title + description: Resource description + mime_type: MIME type + + Example: + ```python + @secure_resource( + "secure://financial/portfolio/{account_id}", + require_auth=True, + data_classification="confidential", + encrypt_io=True, + compliance_tags=["PCI-DSS", "SOC2"] + ) + async def get_portfolio(account_id: str) -> dict: + # Resource implementation + return { + "account_id": account_id, + "balance": 100000, + "holdings": [...] + } + ``` + """ + # Create secure annotations + secure_annotations = SecureAnnotations( + require_auth=require_auth, + data_classification=data_classification, + encrypt_input=encrypt_io, + encrypt_output=encrypt_io, + audit_log=audit_access, + compliance_tags=compliance_tags or [], + ) + + def decorator(func: F) -> F: + @functools.wraps(func) + async def wrapper(*args, **kwargs): + # This wrapper would be replaced by SecureResource.read in production + result = await func(*args, **kwargs) if inspect.iscoroutinefunction(func) else func(*args, **kwargs) + return result + + # Store security metadata on the function + wrapper._secure_annotations = secure_annotations + wrapper._resource_uri = uri + wrapper._is_secure_resource = True + wrapper._data_classification = data_classification + wrapper._compliance_tags = compliance_tags or [] + + return cast(F, wrapper) + + return decorator + + +class SecureResourceTemplate: + """ + Template for secure resources with dynamic URIs. + + Supports resources like "secure://data/{category}/{item_id}" + """ + + def __init__( + self, + uri_template: str, + secure_annotations: SecureAnnotations, + tool_identity: Optional[ToolIdentity] = None, + ): + self.uri_template = uri_template + self.secure_annotations = secure_annotations + self.tool_identity = tool_identity + self.processor = SecureAnnotationProcessor(tool_identity=tool_identity) + + def create_resource(self, **params) -> SecureResource: + """ + Create a secure resource instance with the given parameters. + + Args: + **params: Parameters to fill in the URI template + + Returns: + SecureResource instance + """ + # Format the URI with parameters + uri = self.uri_template.format(**params) + + return SecureResource( + uri=uri, + secure_annotations=self.secure_annotations, + tool_identity=self.tool_identity, + processor=self.processor, + ) + + def validate_access(self, client_identity, params: dict) -> bool: + """ + Validate if a client can access a resource with given parameters. + + Args: + client_identity: Client identity to validate + params: Resource parameters + + Returns: + True if access is allowed, False otherwise + """ + # Check base permissions + if not client_identity.has_permission(f"resource.read.{self.uri_template}"): + return False + + # Check parameter-specific permissions + # For example, for "secure://portfolio/{account_id}", + # check if client can access that specific account + for param_name, param_value in params.items(): + specific_perm = f"resource.{param_name}.{param_value}" + if not client_identity.has_permission(specific_perm): + # Check wildcard permission + if not client_identity.has_permission(f"resource.{param_name}.*"): + return False + + return True \ No newline at end of file diff --git a/src/mcp/server/fastmcp/secure/session.py b/src/mcp/server/fastmcp/secure/session.py new file mode 100644 index 000000000..2f474d8f0 --- /dev/null +++ b/src/mcp/server/fastmcp/secure/session.py @@ -0,0 +1,473 @@ +""" +Session management for secure MCP operations. + +Handles secure session establishment, key exchange, and session lifecycle. +""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import os +import secrets +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Any, Dict, Optional + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM, ChaCha20Poly1305 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec, dh +from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2 +from cryptography.hazmat.primitives.asymmetric.types import PublicKeyTypes + +from .identity import ClientIdentity, ToolIdentity + + +@dataclass +class SecureSession: + """ + Represents a secure session between a client and tool. + + Supports: + - Mutual authentication + - Key exchange and encryption + - Session binding and replay protection + """ + + session_id: str + client_identity: Optional[ClientIdentity] + tool_identity: Optional[ToolIdentity] + established_at: datetime + expires_at: datetime + + # Encryption + encryption_algorithm: str = "AES-256-GCM" # or "ChaCha20-Poly1305" + encryption_key: Optional[AESGCM | ChaCha20Poly1305] = None + client_public_key: Optional[PublicKeyTypes] = None + server_public_key: Optional[PublicKeyTypes] = None + + # Session binding + client_ip: Optional[str] = None + client_fingerprint: Optional[str] = None + bound_to_client: bool = False + + # Replay protection + nonce_counter: int = 0 + used_nonces: set[str] = field(default_factory=set) + max_nonce_age_seconds: int = 300 + + # Rate limiting + request_count: int = 0 + last_request_at: Optional[datetime] = None + + # Metadata + metadata: dict[str, Any] = field(default_factory=dict) + + def is_valid(self) -> bool: + """Check if session is still valid.""" + now = datetime.utcnow() + + # Check expiration + if now > self.expires_at: + return False + + # Check client identity expiration + if self.client_identity and self.client_identity.is_expired(): + return False + + return True + + def is_bound_to(self, client_ip: str, client_fingerprint: str) -> bool: + """ + Check if session is bound to the requesting client. + + Args: + client_ip: Client IP address + client_fingerprint: Client fingerprint (e.g., TLS fingerprint) + + Returns: + True if session binding matches + """ + if not self.bound_to_client: + return True + + if self.client_ip and self.client_ip != client_ip: + return False + + if self.client_fingerprint and self.client_fingerprint != client_fingerprint: + return False + + return True + + def encrypt(self, data: bytes, associated_data: Optional[bytes] = None) -> bytes: + """ + Encrypt data using session key. + + Args: + data: Data to encrypt + associated_data: Additional authenticated data + + Returns: + Encrypted data with nonce prepended + """ + if not self.encryption_key: + raise ValueError("No encryption key established") + + nonce = os.urandom(12) # 96-bit nonce for AES-GCM + ciphertext = self.encryption_key.encrypt(nonce, data, associated_data) + + return nonce + ciphertext + + def decrypt( + self, + encrypted_data: bytes, + associated_data: Optional[bytes] = None + ) -> bytes: + """ + Decrypt data using session key. + + Args: + encrypted_data: Encrypted data with nonce prepended + associated_data: Additional authenticated data + + Returns: + Decrypted data + """ + if not self.encryption_key: + raise ValueError("No encryption key established") + + nonce, ciphertext = encrypted_data[:12], encrypted_data[12:] + + # Check for nonce reuse (replay protection) + nonce_b64 = base64.b64encode(nonce).decode() + if nonce_b64 in self.used_nonces: + raise ValueError("Nonce reuse detected - possible replay attack") + + self.used_nonces.add(nonce_b64) + self.nonce_counter += 1 + + return self.encryption_key.decrypt(nonce, ciphertext, associated_data) + + def generate_request_token(self) -> str: + """ + Generate a request token for replay protection. + + Returns: + Base64-encoded request token + """ + timestamp = datetime.utcnow().isoformat() + nonce = secrets.token_bytes(16) + + token_data = f"{self.session_id}:{timestamp}:{base64.b64encode(nonce).decode()}" + + # Sign the token + if self.encryption_key and isinstance(self.encryption_key, AESGCM): + # Use HMAC with part of the session key + key_bytes = self.encryption_key._key[:16] # Use first 16 bytes for HMAC + signature = hmac.new(key_bytes, token_data.encode(), hashlib.sha256).digest() + + return base64.b64encode( + token_data.encode() + signature + ).decode() + + return base64.b64encode(token_data.encode()).decode() + + def verify_request_token(self, token: str) -> bool: + """ + Verify a request token for replay protection. + + Args: + token: Request token to verify + + Returns: + True if token is valid and fresh + """ + try: + decoded = base64.b64decode(token) + + if self.encryption_key and isinstance(self.encryption_key, AESGCM): + # Split token and signature + token_data = decoded[:-32] + signature = decoded[-32:] + + # Verify signature + key_bytes = self.encryption_key._key[:16] + expected_sig = hmac.new(key_bytes, token_data, hashlib.sha256).digest() + + if not hmac.compare_digest(signature, expected_sig): + return False + else: + token_data = decoded + + # Parse token + parts = token_data.decode().split(":") + if len(parts) != 3: + return False + + session_id, timestamp_str, nonce_b64 = parts + + # Verify session ID + if session_id != self.session_id: + return False + + # Check timestamp freshness + timestamp = datetime.fromisoformat(timestamp_str) + age = (datetime.utcnow() - timestamp).total_seconds() + + if age > self.max_nonce_age_seconds: + return False + + # Check nonce uniqueness + if nonce_b64 in self.used_nonces: + return False + + self.used_nonces.add(nonce_b64) + + return True + + except Exception: + return False + + def rotate_session_key(self) -> None: + """Rotate the session encryption key.""" + if not self.encryption_key: + return + + # Derive new key from old key + if isinstance(self.encryption_key, AESGCM): + old_key = self.encryption_key._key + + # Use HKDF to derive new key + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=32, + salt=None, + info=b'session-key-rotation', + ) + new_key = hkdf.derive(old_key + self.session_id.encode()) + + self.encryption_key = AESGCM(new_key) + + # Clear nonce history on key rotation + self.used_nonces.clear() + self.nonce_counter = 0 + + +class SessionManager: + """ + Manages secure sessions for MCP operations. + """ + + def __init__( + self, + tool_identity: Optional[ToolIdentity] = None, + session_timeout_minutes: int = 60, + max_sessions_per_client: int = 10, + ): + self.tool_identity = tool_identity + self.session_timeout_minutes = session_timeout_minutes + self.max_sessions_per_client = max_sessions_per_client + + # Session storage + self.sessions: Dict[str, SecureSession] = {} + self.client_sessions: Dict[str, list[str]] = {} # client_id -> [session_ids] + + # DH parameters for key exchange + self._dh_parameters = None + self._ecdh_curve = ec.SECP256R1() + + def create_session( + self, + client_identity: Optional[ClientIdentity] = None, + encryption_algorithm: str = "AES-256-GCM", + bind_to_client: bool = False, + client_ip: Optional[str] = None, + client_fingerprint: Optional[str] = None, + ) -> SecureSession: + """ + Create a new secure session. + + Args: + client_identity: Authenticated client identity + encryption_algorithm: Encryption algorithm to use + bind_to_client: Whether to bind session to client + client_ip: Client IP for session binding + client_fingerprint: Client fingerprint for session binding + + Returns: + New SecureSession instance + """ + # Generate session ID + session_id = base64.urlsafe_b64encode(os.urandom(32)).decode().rstrip("=") + + # Check session limit per client + if client_identity: + client_id = client_identity.client_id + if client_id in self.client_sessions: + if len(self.client_sessions[client_id]) >= self.max_sessions_per_client: + # Remove oldest session + oldest_session_id = self.client_sessions[client_id][0] + self.revoke_session(oldest_session_id) + + # Create session + session = SecureSession( + session_id=session_id, + client_identity=client_identity, + tool_identity=self.tool_identity, + established_at=datetime.utcnow(), + expires_at=datetime.utcnow() + timedelta(minutes=self.session_timeout_minutes), + encryption_algorithm=encryption_algorithm, + bound_to_client=bind_to_client, + client_ip=client_ip, + client_fingerprint=client_fingerprint, + ) + + # Store session + self.sessions[session_id] = session + + # Track client sessions + if client_identity: + client_id = client_identity.client_id + if client_id not in self.client_sessions: + self.client_sessions[client_id] = [] + self.client_sessions[client_id].append(session_id) + + return session + + def get_session(self, session_id: str) -> Optional[SecureSession]: + """ + Get a session by ID. + + Args: + session_id: Session ID + + Returns: + SecureSession if found and valid + """ + session = self.sessions.get(session_id) + + if session and session.is_valid(): + return session + + # Remove invalid session + if session: + self.revoke_session(session_id) + + return None + + def revoke_session(self, session_id: str) -> None: + """ + Revoke a session. + + Args: + session_id: Session ID to revoke + """ + session = self.sessions.pop(session_id, None) + + if session and session.client_identity: + # Remove from client sessions + client_id = session.client_identity.client_id + if client_id in self.client_sessions: + self.client_sessions[client_id] = [ + sid for sid in self.client_sessions[client_id] + if sid != session_id + ] + + def perform_ecdh_key_exchange( + self, + session: SecureSession, + client_public_key_pem: bytes + ) -> bytes: + """ + Perform ECDH key exchange to establish session key. + + Args: + session: Session to establish key for + client_public_key_pem: Client's public key in PEM format + + Returns: + Server's public key in PEM format + """ + # Generate server's ephemeral key pair + server_private_key = ec.generate_private_key(self._ecdh_curve) + server_public_key = server_private_key.public_key() + + # Load client's public key + client_public_key = serialization.load_pem_public_key(client_public_key_pem) + + # Perform ECDH to get shared secret + shared_secret = server_private_key.exchange( + ec.ECDH(), + client_public_key + ) + + # Derive session key using HKDF + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=32, # 256-bit key + salt=session.session_id.encode()[:16], # Use session ID as salt + info=b'mcp-session-key', + ) + session_key = hkdf.derive(shared_secret) + + # Create cipher based on algorithm + if session.encryption_algorithm == "ChaCha20-Poly1305": + session.encryption_key = ChaCha20Poly1305(session_key) + else: # Default to AES-256-GCM + session.encryption_key = AESGCM(session_key) + + # Store public keys + session.client_public_key = client_public_key + session.server_public_key = server_public_key + + # Return server's public key + return server_public_key.public_key_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + + def establish_pre_shared_key( + self, + session: SecureSession, + pre_shared_key: bytes + ) -> None: + """ + Establish session key from pre-shared key. + + Args: + session: Session to establish key for + pre_shared_key: Pre-shared key + """ + # Derive session key from PSK using PBKDF2 + kdf = PBKDF2( + algorithm=hashes.SHA256(), + length=32, + salt=session.session_id.encode()[:16], + iterations=100000, + ) + session_key = kdf.derive(pre_shared_key) + + # Create cipher + if session.encryption_algorithm == "ChaCha20-Poly1305": + session.encryption_key = ChaCha20Poly1305(session_key) + else: + session.encryption_key = AESGCM(session_key) + + def cleanup_expired_sessions(self) -> int: + """ + Clean up expired sessions. + + Returns: + Number of sessions removed + """ + expired_sessions = [ + session_id for session_id, session in self.sessions.items() + if not session.is_valid() + ] + + for session_id in expired_sessions: + self.revoke_session(session_id) + + return len(expired_sessions) \ No newline at end of file diff --git a/src/mcp/server/fastmcp/secure/tool.py b/src/mcp/server/fastmcp/secure/tool.py new file mode 100644 index 000000000..554f7e0d7 --- /dev/null +++ b/src/mcp/server/fastmcp/secure/tool.py @@ -0,0 +1,272 @@ +""" +Secure tool implementation with authentication and encryption. +""" + +from __future__ import annotations + +import functools +import inspect +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, cast + +from mcp.server.fastmcp.tools.base import Tool +from mcp.types import ContentBlock, Error + +from .annotations import SecureAnnotations, SecureToolAnnotations +from .identity import ClientIdentity, ToolIdentity +from .session import SecureSession +from .utils import SecureAnnotationProcessor + +if TYPE_CHECKING: + from mcp.server.fastmcp.server import Context, FastMCP + +F = TypeVar('F', bound=Callable[..., Any]) + + +class SecureTool(Tool): + """ + Secure tool with authentication, encryption, and attestation support. + + This extends the base Tool class with security features. + """ + + def __init__( + self, + secure_annotations: SecureAnnotations, + tool_identity: Optional[ToolIdentity] = None, + processor: Optional[SecureAnnotationProcessor] = None, + **kwargs + ): + super().__init__(**kwargs) + self.secure_annotations = secure_annotations + self.tool_identity = tool_identity + self.processor = processor or SecureAnnotationProcessor(tool_identity=tool_identity) + + # Update annotations with security metadata + if not self.annotations: + self.annotations = SecureToolAnnotations(secure=secure_annotations) + elif hasattr(self.annotations, 'extensions'): + self.annotations.extensions["security"] = secure_annotations.to_dict() + + async def run( + self, + arguments: dict[str, Any], + context: Context | None = None, + convert_result: bool = False, + ) -> Any: + """ + Run the secure tool with authentication and encryption. + + This method: + 1. Authenticates the client (if required) + 2. Verifies tool identity (if mutual auth is enabled) + 3. Decrypts input (if encryption is enabled) + 4. Executes the tool + 5. Encrypts output (if encryption is enabled) + 6. Signs the result (if attestation is enabled) + """ + # Extract authentication information from context + auth_header = None + client_cert = None + if context and hasattr(context, 'request_context'): + request = getattr(context.request_context, 'request', None) + if request: + auth_header = request.headers.get('Authorization') + # In production, extract client cert from TLS connection + + # Process secure request (authenticate, decrypt, etc.) + try: + session, processed_args = await self.processor.process_secure_request( + annotations=self.secure_annotations, + auth_header=auth_header, + client_cert=client_cert, + request_data=arguments + ) + except Error as e: + # Log authentication failure + if context: + await context.error(f"Security check failed: {e.message}") + raise + + # If mutual authentication is required, send tool attestation + if self.secure_annotations.require_mutual_auth and self.tool_identity: + attestation = self.tool_identity.to_attestation() + if context: + await context.info(f"Tool attestation: {self.name} (fingerprint: {attestation['fingerprint'][:16]}...)") + + # Log the authenticated execution + if context and session.client_identity: + await context.info( + f"Executing secure tool '{self.name}' for client '{session.client_identity.client_id}' " + f"(auth: {session.client_identity.authentication_method.value})" + ) + + # Execute the actual tool function with processed arguments + try: + # Inject session into arguments if function expects it + sig = inspect.signature(self.fn) + if '_secure_session' in sig.parameters: + processed_args['_secure_session'] = session + + result = await super().run( + arguments=processed_args, + context=context, + convert_result=convert_result + ) + except Exception as e: + # Audit the failure + if self.secure_annotations.audit_log: + if context: + await context.error(f"Tool execution failed: {str(e)}") + raise + + # Process secure response (encrypt, sign, etc.) + secure_result = await self.processor.process_secure_response( + annotations=self.secure_annotations, + session=session, + response_data=result + ) + + return secure_result + + @classmethod + def from_function( + cls, + fn: Callable[..., Any], + secure_annotations: SecureAnnotations, + tool_identity: Optional[ToolIdentity] = None, + name: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + **kwargs + ) -> SecureTool: + """Create a SecureTool from a function.""" + # Create base tool first + base_tool = Tool.from_function( + fn=fn, + name=name, + title=title, + description=description, + **kwargs + ) + + # Create secure tool with same properties + return cls( + fn=base_tool.fn, + name=base_tool.name, + title=base_tool.title, + description=base_tool.description, + parameters=base_tool.parameters, + fn_metadata=base_tool.fn_metadata, + is_async=base_tool.is_async, + context_kwarg=base_tool.context_kwarg, + secure_annotations=secure_annotations, + tool_identity=tool_identity, + annotations=SecureToolAnnotations(secure=secure_annotations) + ) + + +def secure_tool( + # Security parameters + require_auth: bool = False, + auth_methods: Optional[list] = None, + required_permissions: Optional[set[str]] = None, + encrypt_io: bool = False, + require_mutual_auth: bool = False, + security_level: str = "standard", + tool_identity: Optional[ToolIdentity] = None, + + # Standard tool parameters + name: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + structured_output: Optional[bool] = None, +) -> Callable[[F], F]: + """ + Decorator to create a secure tool with authentication and encryption. + + This decorator wraps a function to create a secure MCP tool that supports: + - Client authentication (JWT, certificates, TEE attestation) + - Bidirectional authentication (tool ↔ client) + - Input/output encryption + - Audit logging and compliance + + Args: + require_auth: Whether to require authentication + auth_methods: List of accepted authentication methods + required_permissions: Permissions required to execute the tool + encrypt_io: Whether to encrypt input and output + require_mutual_auth: Whether to require bidirectional authentication + security_level: Security level (standard/high/critical) + tool_identity: Tool identity for attestation + name: Tool name + title: Tool title + description: Tool description + structured_output: Whether to use structured output + + Example: + ```python + @secure_tool( + require_auth=True, + required_permissions={"trade.execute"}, + encrypt_io=True, + require_mutual_auth=True + ) + async def execute_trade(symbol: str, amount: float, ctx: Context) -> str: + # Tool implementation + return f"Trade executed: {symbol} x {amount}" + ``` + """ + from .annotations import AuthMethod + + # Create secure annotations + secure_annotations = SecureAnnotations( + require_auth=require_auth, + auth_methods=auth_methods or [AuthMethod.JWT], + required_permissions=required_permissions or set(), + encrypt_input=encrypt_io, + encrypt_output=encrypt_io, + require_mutual_auth=require_mutual_auth, + security_level=security_level, + ) + + def decorator(func: F) -> F: + # Check if this is being used with FastMCP + # In production, this would be integrated with FastMCP.tool() + + @functools.wraps(func) + async def wrapper(*args, **kwargs): + # This wrapper would be replaced by SecureTool.run in production + # For now, just call the function + return await func(*args, **kwargs) if inspect.iscoroutinefunction(func) else func(*args, **kwargs) + + # Store security metadata on the function + wrapper._secure_annotations = secure_annotations + wrapper._tool_identity = tool_identity + wrapper._is_secure_tool = True + + return cast(F, wrapper) + + return decorator + + +def create_secure_tool_from_function( + fn: Callable[..., Any], + mcp: FastMCP, + secure_annotations: SecureAnnotations, + tool_identity: Optional[ToolIdentity] = None, + **kwargs +) -> None: + """ + Helper function to add a secure tool to a FastMCP instance. + + This would be called internally by FastMCP when a secure_tool decorator is used. + """ + secure_tool_instance = SecureTool.from_function( + fn=fn, + secure_annotations=secure_annotations, + tool_identity=tool_identity, + **kwargs + ) + + # Register with the tool manager + mcp._tool_manager._tools[secure_tool_instance.name] = secure_tool_instance \ No newline at end of file diff --git a/src/mcp/server/fastmcp/secure/utils.py b/src/mcp/server/fastmcp/secure/utils.py new file mode 100644 index 000000000..e0a935c9a --- /dev/null +++ b/src/mcp/server/fastmcp/secure/utils.py @@ -0,0 +1,628 @@ +""" +Utility functions for secure MCP operations. + +Provides encryption, authentication, and security helper functions. +""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import json +import os +from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional, Tuple + +import jwt +from cryptography import x509 +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2 + +from mcp.types import Error + +from .annotations import AuthMethod, SecureAnnotations +from .identity import ClientIdentity, ToolIdentity +from .session import SecureSession + + +class SecureAnnotationProcessor: + """ + Processes secure annotations for tools, resources, and prompts. + + This class handles the actual authentication, encryption, and + attestation logic when secure annotations are present. + """ + + def __init__( + self, + tool_identity: Optional[ToolIdentity] = None, + jwt_secret: Optional[str] = None, + trusted_cas: Optional[List[x509.Certificate]] = None, + api_keys: Optional[Dict[str, ClientIdentity]] = None, + ): + self.tool_identity = tool_identity + self.jwt_secret = jwt_secret or os.environ.get("MCP_JWT_SECRET") + self.trusted_cas = trusted_cas or [] + self.api_keys = api_keys or {} + + # Session and rate limit storage + self.sessions: Dict[str, SecureSession] = {} + self.rate_limits: Dict[str, List[datetime]] = {} + + # Audit log (in production, use proper logging system) + self.audit_log_entries: List[dict] = [] + + async def process_secure_request( + self, + annotations: SecureAnnotations, + auth_header: Optional[str] = None, + client_cert: Optional[x509.Certificate] = None, + request_data: Optional[dict[str, Any]] = None, + ) -> Tuple[SecureSession, dict[str, Any]]: + """ + Process a secure request with authentication and encryption. + + Args: + annotations: Security annotations + auth_header: Authorization header + client_cert: Client certificate + request_data: Request data + + Returns: + Tuple of (secure_session, processed_request_data) + + Raises: + Error: If security checks fail + """ + # 1. Authenticate client if required + client_identity = None + if annotations.require_auth: + client_identity = await self._authenticate_client( + annotations.auth_methods, + auth_header, + client_cert, + request_data + ) + + # Check permissions + missing_perms = annotations.required_permissions - client_identity.permissions + if missing_perms: + raise Error( + code=403, + message=f"Missing required permissions: {missing_perms}" + ) + + # 2. Perform mutual authentication if required + if annotations.require_mutual_auth: + if not self.tool_identity: + raise Error( + code=500, + message="Tool identity not configured for mutual authentication" + ) + # Tool attestation is provided through session + + # 3. Create or retrieve session + session = await self._establish_session(client_identity) + + # 4. Verify tool attestation if required + if annotations.require_tool_attestation: + if not self._verify_tool_attestation(annotations): + raise Error( + code=403, + message="Tool attestation verification failed" + ) + + # 5. Check rate limits + if annotations.rate_limit: + self._check_rate_limit( + session.client_identity.client_id if session.client_identity else "anonymous", + annotations.rate_limit, + annotations.rate_limit_per_client + ) + + # 6. Decrypt input if required + processed_data = request_data or {} + if annotations.encrypt_input and session.encryption_key: + processed_data = self._decrypt_request_data(session, processed_data) + + # 7. Verify message integrity if required + if annotations.require_integrity_check: + self._verify_message_integrity(processed_data) + + # 8. Check replay protection if required + if annotations.require_replay_protection: + if not self._check_replay_protection(session, processed_data): + raise Error(code=400, message="Replay attack detected") + + # 9. Audit log + if annotations.audit_log: + self._audit_log( + session=session, + action="request", + include_data=annotations.audit_include_inputs, + data=processed_data if annotations.audit_include_inputs else None + ) + + return session, processed_data + + async def process_secure_response( + self, + annotations: SecureAnnotations, + session: SecureSession, + response_data: Any, + ) -> Any: + """ + Process a secure response with encryption and signing. + + Args: + annotations: Security annotations + session: Secure session + response_data: Response data + + Returns: + Processed response data + """ + # 1. Audit log + if annotations.audit_log: + self._audit_log( + session=session, + action="response", + include_data=annotations.audit_include_outputs, + data=response_data if annotations.audit_include_outputs else None + ) + + # 2. Encrypt output if required + if annotations.encrypt_output and session.encryption_key: + response_data = self._encrypt_response_data(session, response_data) + + # 3. Add integrity signature if required + if annotations.require_integrity_check: + response_data = self._add_integrity_signature(response_data) + + # 4. Sign response if tool signature is required + if annotations.tool_signature_required and self.tool_identity: + response_data = self._sign_response(response_data) + + # 5. Add session metadata + if isinstance(response_data, dict): + response_data["_session"] = { + "id": session.session_id[:8] + "...", # Truncated for security + "authenticated": session.client_identity is not None, + "encrypted": annotations.encrypt_output, + } + + return response_data + + async def _authenticate_client( + self, + auth_methods: List[AuthMethod], + auth_header: Optional[str], + client_cert: Optional[x509.Certificate], + request_data: Optional[dict], + ) -> ClientIdentity: + """Authenticate client using available methods.""" + + # Try JWT authentication + if AuthMethod.JWT in auth_methods and auth_header: + identity = self._authenticate_jwt(auth_header) + if identity: + return identity + + # Try API key authentication + if AuthMethod.API_KEY in auth_methods: + api_key = None + if auth_header and auth_header.startswith("Bearer "): + api_key = auth_header[7:] + elif request_data and "api_key" in request_data: + api_key = request_data["api_key"] + + if api_key: + identity = self._authenticate_api_key(api_key) + if identity: + return identity + + # Try certificate authentication + if AuthMethod.CERTIFICATE in auth_methods and client_cert: + identity = self._authenticate_certificate(client_cert) + if identity: + return identity + + # Try mutual TLS + if AuthMethod.MTLS in auth_methods and client_cert: + identity = self._authenticate_mtls(client_cert) + if identity: + return identity + + raise Error(code=401, message="Authentication failed") + + def _authenticate_jwt(self, auth_header: str) -> Optional[ClientIdentity]: + """Authenticate using JWT token.""" + if not auth_header.startswith("Bearer "): + return None + + token = auth_header[7:] + + try: + # Decode and verify JWT + claims = jwt.decode( + token, + self.jwt_secret, + algorithms=["HS256", "RS256", "ES256"] + ) + + return ClientIdentity( + client_id=claims.get("sub", "unknown"), + authentication_method=AuthMethod.JWT, + credentials=token, + permissions=set(claims.get("permissions", [])), + email=claims.get("email"), + organization=claims.get("org"), + expires_at=datetime.fromtimestamp(claims.get("exp", 0)), + metadata={"claims": claims} + ) + except jwt.InvalidTokenError: + return None + + def _authenticate_api_key(self, api_key: str) -> Optional[ClientIdentity]: + """Authenticate using API key.""" + return self.api_keys.get(api_key) + + def _authenticate_certificate(self, cert: x509.Certificate) -> Optional[ClientIdentity]: + """Authenticate using X.509 certificate.""" + # Verify certificate against trusted CAs + for ca in self.trusted_cas: + try: + ca.public_key().verify( + cert.signature, + cert.tbs_certificate_bytes, + cert.signature_algorithm_oid._name + ) + + # Extract client info from certificate + from cryptography.x509.oid import NameOID + + common_name = cert.subject.get_attributes_for_oid( + NameOID.COMMON_NAME + )[0].value + + org = None + org_attrs = cert.subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME) + if org_attrs: + org = org_attrs[0].value + + return ClientIdentity( + client_id=common_name, + authentication_method=AuthMethod.CERTIFICATE, + credentials=cert, + permissions={"read", "write", "execute"}, # Extract from cert extensions + organization=org, + certificate=cert, + certificate_fingerprint=hashlib.sha256( + cert.public_bytes(x509.Encoding.DER) + ).hexdigest(), + ) + except Exception: + continue + + return None + + def _authenticate_mtls(self, cert: x509.Certificate) -> Optional[ClientIdentity]: + """Authenticate using mutual TLS.""" + # Similar to certificate auth but with bidirectional verification + identity = self._authenticate_certificate(cert) + + if identity and self.tool_identity: + # Verify that client also verified our tool certificate + # This would be handled at the TLS layer in production + identity.metadata["mtls_verified"] = True + + return identity + + async def _establish_session( + self, + client_identity: Optional[ClientIdentity] + ) -> SecureSession: + """Establish or retrieve a secure session.""" + # For simplicity, create a new session each time + # In production, implement session caching + + session_id = base64.b64encode(os.urandom(32)).decode() + + # Create encryption key if we have a client + encryption_key = None + if client_identity: + key_bytes = AESGCM.generate_key(bit_length=256) + encryption_key = AESGCM(key_bytes) + + session = SecureSession( + session_id=session_id, + client_identity=client_identity, + tool_identity=self.tool_identity, + established_at=datetime.utcnow(), + expires_at=datetime.utcnow() + timedelta(hours=1), + encryption_key=encryption_key + ) + + self.sessions[session_id] = session + return session + + def _verify_tool_attestation(self, annotations: SecureAnnotations) -> bool: + """Verify tool attestation matches requirements.""" + if not self.tool_identity: + return False + + # Check certificate fingerprint if specified + if annotations.tool_certificate_fingerprint: + if self.tool_identity.fingerprint != annotations.tool_certificate_fingerprint: + return False + + # Check attestation type if specified + if annotations.attestation_type: + if self.tool_identity.attestation_type != annotations.attestation_type: + return False + + return True + + def _check_rate_limit( + self, + client_id: str, + limit: int, + per_client: bool + ) -> None: + """Check and enforce rate limits.""" + key = client_id if per_client else "global" + now = datetime.utcnow() + + # Clean old entries + if key in self.rate_limits: + self.rate_limits[key] = [ + t for t in self.rate_limits[key] + if (now - t).total_seconds() < 60 + ] + else: + self.rate_limits[key] = [] + + # Check limit + if len(self.rate_limits[key]) >= limit: + raise Error(code=429, message="Rate limit exceeded") + + # Add current request + self.rate_limits[key].append(now) + + def _decrypt_request_data( + self, + session: SecureSession, + data: dict[str, Any] + ) -> dict[str, Any]: + """Decrypt request data.""" + decrypted = {} + for key, value in data.items(): + if isinstance(value, str) and value.startswith("ENC:"): + encrypted_bytes = base64.b64decode(value[4:]) + decrypted_bytes = session.decrypt(encrypted_bytes) + decrypted[key] = json.loads(decrypted_bytes) + elif isinstance(value, dict): + decrypted[key] = self._decrypt_request_data(session, value) + else: + decrypted[key] = value + return decrypted + + def _encrypt_response_data( + self, + session: SecureSession, + data: Any + ) -> dict[str, Any]: + """Encrypt response data.""" + json_data = json.dumps(data) + encrypted = session.encrypt(json_data.encode()) + + return { + "encrypted": True, + "algorithm": session.encryption_algorithm, + "data": "ENC:" + base64.b64encode(encrypted).decode(), + "session_id": session.session_id + } + + def _verify_message_integrity(self, data: dict) -> bool: + """Verify message integrity signature.""" + if "_integrity" not in data: + return True # No integrity check provided + + integrity = data.pop("_integrity") + + # Compute expected hash + data_str = json.dumps(data, sort_keys=True) + expected_hash = hashlib.sha256(data_str.encode()).hexdigest() + + return hmac.compare_digest(integrity, expected_hash) + + def _add_integrity_signature(self, data: Any) -> dict: + """Add integrity signature to response.""" + if isinstance(data, dict): + data_copy = data.copy() + else: + data_copy = {"value": data} + + # Compute hash + data_str = json.dumps(data_copy, sort_keys=True) + integrity = hashlib.sha256(data_str.encode()).hexdigest() + + data_copy["_integrity"] = integrity + return data_copy + + def _check_replay_protection( + self, + session: SecureSession, + data: dict + ) -> bool: + """Check for replay attacks.""" + if "_request_token" not in data: + return False + + token = data.pop("_request_token") + return session.verify_request_token(token) + + def _sign_response(self, data: Any) -> dict[str, Any]: + """Sign response data with tool identity.""" + if not self.tool_identity: + return data if isinstance(data, dict) else {"value": data} + + # Prepare data for signing + if isinstance(data, dict): + sign_data = data.copy() + else: + sign_data = {"value": data} + + # Add timestamp + sign_data["_timestamp"] = datetime.utcnow().isoformat() + + # Sign the data + json_data = json.dumps(sign_data, sort_keys=True) + signature = self.tool_identity.sign_data(json_data.encode()) + + return { + "data": sign_data, + "signature": base64.b64encode(signature).decode(), + "tool_id": self.tool_identity.tool_id, + "tool_fingerprint": self.tool_identity.fingerprint[:16] + "..." + } + + def _audit_log( + self, + session: Optional[SecureSession], + action: str, + include_data: bool, + data: Any = None + ) -> None: + """Create audit log entry.""" + log_entry = { + "timestamp": datetime.utcnow().isoformat(), + "action": action, + } + + if session: + log_entry["session_id"] = session.session_id + if session.client_identity: + log_entry["client_id"] = session.client_identity.client_id + log_entry["auth_method"] = session.client_identity.authentication_method.value + if session.tool_identity: + log_entry["tool_id"] = session.tool_identity.tool_id + + if include_data and data is not None: + # Hash sensitive data for audit + if isinstance(data, (dict, list)): + log_entry["data_hash"] = hashlib.sha256( + json.dumps(data, sort_keys=True).encode() + ).hexdigest() + else: + log_entry["data_hash"] = hashlib.sha256( + str(data).encode() + ).hexdigest() + + self.audit_log_entries.append(log_entry) + + # In production, write to proper audit system + # For now, just keep in memory + + +# Convenience functions for encryption/decryption +def encrypt_data(data: str, key: bytes) -> str: + """ + Encrypt data using AES-256-GCM. + + Args: + data: Data to encrypt + key: 256-bit encryption key + + Returns: + Base64-encoded encrypted data + """ + cipher = AESGCM(key) + nonce = os.urandom(12) + ciphertext = cipher.encrypt(nonce, data.encode(), None) + return base64.b64encode(nonce + ciphertext).decode() + + +def decrypt_data(encrypted: str, key: bytes) -> str: + """ + Decrypt data encrypted with AES-256-GCM. + + Args: + encrypted: Base64-encoded encrypted data + key: 256-bit encryption key + + Returns: + Decrypted data + """ + cipher = AESGCM(key) + raw = base64.b64decode(encrypted) + nonce, ciphertext = raw[:12], raw[12:] + plaintext = cipher.decrypt(nonce, ciphertext, None) + return plaintext.decode() + + +def generate_session_key(password: str, salt: Optional[bytes] = None) -> bytes: + """ + Generate a session key from a password. + + Args: + password: Password to derive key from + salt: Optional salt (will generate if not provided) + + Returns: + 256-bit key suitable for AES-256 + """ + if salt is None: + salt = os.urandom(16) + + kdf = PBKDF2( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=100000, + ) + return kdf.derive(password.encode()) + + +def verify_signature( + data: bytes, + signature: bytes, + public_key_pem: bytes +) -> bool: + """ + Verify a digital signature. + + Args: + data: Data that was signed + signature: Digital signature + public_key_pem: Public key in PEM format + + Returns: + True if signature is valid + """ + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import ec, rsa, padding + + try: + public_key = serialization.load_pem_public_key(public_key_pem) + + if isinstance(public_key, rsa.RSAPublicKey): + public_key.verify( + signature, + data, + padding.PSS( + mgf=padding.MGF1(hashes.SHA256()), + salt_length=padding.PSS.MAX_LENGTH + ), + hashes.SHA256() + ) + elif isinstance(public_key, ec.EllipticCurvePublicKey): + public_key.verify(signature, data, ec.ECDSA(hashes.SHA256())) + else: + return False + + return True + except Exception: + return False \ No newline at end of file