|
| 1 | +"""Encrypted Session wrapper for secure conversation storage. |
| 2 | +
|
| 3 | +This module provides transparent encryption for session storage with automatic |
| 4 | +expiration of old data. When TTL expires, expired items are silently skipped. |
| 5 | +
|
| 6 | +Usage:: |
| 7 | +
|
| 8 | + from agents.extensions.memory import EncryptedSession, SQLAlchemySession |
| 9 | +
|
| 10 | + # Create underlying session (e.g. SQLAlchemySession) |
| 11 | + underlying_session = SQLAlchemySession.from_url( |
| 12 | + session_id="user-123", |
| 13 | + url="postgresql+asyncpg://app:secret@db.example.com/agents", |
| 14 | + create_tables=True, |
| 15 | + ) |
| 16 | +
|
| 17 | + # Wrap with encryption and TTL-based expiration |
| 18 | + session = EncryptedSession( |
| 19 | + session_id="user-123", |
| 20 | + underlying_session=underlying_session, |
| 21 | + encryption_key="your-encryption-key", |
| 22 | + ttl=600, # 10 minutes |
| 23 | + ) |
| 24 | +
|
| 25 | + await Runner.run(agent, "Hello", session=session) |
| 26 | +""" |
| 27 | + |
| 28 | +from __future__ import annotations |
| 29 | + |
| 30 | +import base64 |
| 31 | +import json |
| 32 | +from typing import Any, cast |
| 33 | + |
| 34 | +from cryptography.fernet import Fernet, InvalidToken |
| 35 | +from cryptography.hazmat.primitives import hashes |
| 36 | +from cryptography.hazmat.primitives.kdf.hkdf import HKDF |
| 37 | +from typing_extensions import Literal, TypedDict, TypeGuard |
| 38 | + |
| 39 | +from ...items import TResponseInputItem |
| 40 | +from ...memory.session import SessionABC |
| 41 | + |
| 42 | + |
class EncryptedEnvelope(TypedDict):
    """Shape of the dict written to the underlying session in place of a plaintext item.

    The envelope carries an encrypted token plus the metadata needed to
    recognise it as an envelope when it is read back.
    """

    # Marker field; always the literal 1 for an encrypted envelope.
    __enc__: Literal[1]
    # Envelope format version.
    v: int
    # Identifier of the key-derivation scheme used to produce the token.
    kid: str
    # The encrypted token, stored as text.
    payload: str
| 50 | + |
| 51 | + |
def _ensure_fernet_key_bytes(master_key: str) -> bytes:
    """Normalise the user-supplied master key into raw bytes for HKDF input.

    Accepts either a Fernet key (urlsafe-base64 text that decodes to exactly
    32 bytes) or an arbitrary secret string.

    Args:
        master_key: The configured encryption key.

    Returns:
        The 32 decoded bytes when ``master_key`` is a Fernet-style key,
        otherwise the UTF-8 encoding of ``master_key``.

    Raises:
        ValueError: If ``master_key`` is empty.
    """
    if not master_key:
        raise ValueError("encryption_key not set; required for EncryptedSession.")

    try:
        key_bytes = base64.urlsafe_b64decode(master_key)
    except ValueError:
        # Not valid base64 (binascii.Error is a ValueError subclass) or not
        # ASCII: the key is a raw secret, so fall through to UTF-8 encoding.
        # Only decode failures are swallowed here; anything else propagates.
        pass
    else:
        if len(key_bytes) == 32:
            return key_bytes

    return master_key.encode("utf-8")
| 66 | + |
| 67 | + |
def _derive_session_fernet_key(master_key_bytes: bytes, session_id: str) -> Fernet:
    """Build the Fernet cipher for one session from the master key.

    A 32-byte key is derived with HKDF-SHA256, using the UTF-8 encoded
    ``session_id`` as the salt and a fixed info string, so each session ID
    yields its own derived key from the same master key.

    Args:
        master_key_bytes: Raw master key material.
        session_id: Session identifier used as the HKDF salt.

    Returns:
        A ``Fernet`` instance keyed with the derived key.
    """
    kdf = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=session_id.encode("utf-8"),
        info=b"agents.session-store.hkdf.v1",
    )
    raw_key = kdf.derive(master_key_bytes)
    # Fernet takes its key in urlsafe-base64 form rather than as raw bytes.
    encoded_key = base64.urlsafe_b64encode(raw_key)
    return Fernet(encoded_key)
| 77 | + |
| 78 | + |
def _to_json_bytes(obj: Any) -> bytes:
    """Serialise ``obj`` to compact, UTF-8 encoded JSON.

    Non-ASCII characters are emitted as-is, separators carry no whitespace,
    and values the JSON encoder cannot handle natively are converted with
    ``str``.
    """
    text = json.dumps(
        obj,
        ensure_ascii=False,
        separators=(",", ":"),
        default=str,
    )
    return text.encode("utf-8")
| 81 | + |
| 82 | + |
def _from_json_bytes(data: bytes) -> Any:
    """Decode UTF-8 bytes and parse them as JSON, returning the parsed value."""
    text = data.decode("utf-8")
    return json.loads(text)
| 85 | + |
| 86 | + |
def _is_encrypted_envelope(item: object) -> TypeGuard[EncryptedEnvelope]:
    """Return True when ``item`` looks like an encrypted envelope.

    An envelope is a dict whose ``"__enc__"`` value equals 1 and which also
    has the ``"payload"``, ``"kid"`` and ``"v"`` keys. Only key presence is
    checked; the values of those keys are not inspected.
    """
    if not isinstance(item, dict):
        return False
    if item.get("__enc__") != 1:
        return False
    return all(key in item for key in ("payload", "kid", "v"))
| 96 | + |
| 97 | + |
class EncryptedSession(SessionABC):
    """Encrypted wrapper for Session implementations with TTL-based expiration.

    This class wraps any SessionABC implementation to provide transparent
    encryption/decryption of stored items using Fernet encryption with
    per-session key derivation and automatic expiration of old data.

    When items expire (exceed TTL), they are silently skipped during retrieval.
    Items in the underlying store that are not encrypted envelopes are passed
    through unchanged.

    Note: Expired tokens are rejected based on the system clock of the application server.
    To avoid valid tokens being rejected due to clock drift, ensure all servers in
    your environment are synchronized using NTP.
    """

    def __init__(
        self,
        session_id: str,
        underlying_session: SessionABC,
        encryption_key: str,
        ttl: int = 600,
    ):
        """
        Args:
            session_id: ID for this session
            underlying_session: The real session store (e.g. SQLiteSession, SQLAlchemySession)
            encryption_key: Master key (Fernet key or raw secret)
            ttl: Token time-to-live in seconds (default 10 min)

        Raises:
            ValueError: If ``encryption_key`` is empty.
        """
        self.session_id = session_id
        self.underlying_session = underlying_session
        self.ttl = ttl

        master = _ensure_fernet_key_bytes(encryption_key)
        self.cipher = _derive_session_fernet_key(master, session_id)
        self._kid = "hkdf-v1"
        self._ver = 1

    def __getattr__(self, name: str) -> Any:
        """Delegate attributes not found on the wrapper to the underlying session.

        Raises:
            AttributeError: If the attribute exists on neither object, or if the
                underlying session itself has not been set yet.
        """
        # __getattr__ only runs after normal lookup has failed. If the missing
        # attribute is the wrapped session itself (e.g. copy/pickle probing an
        # instance whose __init__ has not run), reading self.underlying_session
        # below would re-enter this method and recurse until RecursionError.
        if name == "underlying_session":
            raise AttributeError(name)
        return getattr(self.underlying_session, name)

    def _wrap(self, item: TResponseInputItem) -> EncryptedEnvelope:
        """Encrypt one item and return the envelope to store in its place."""
        # Reduce the item to a plain mapping before JSON serialisation.
        if isinstance(item, dict):
            payload = item
        elif hasattr(item, "model_dump"):
            payload = item.model_dump()
        elif hasattr(item, "__dict__"):
            payload = item.__dict__
        else:
            payload = dict(item)

        token = self.cipher.encrypt(_to_json_bytes(payload)).decode("utf-8")
        return {"__enc__": 1, "v": self._ver, "kid": self._kid, "payload": token}

    def _unwrap(self, item: TResponseInputItem | EncryptedEnvelope) -> TResponseInputItem | None:
        """Decrypt a stored item.

        Returns:
            The item unchanged if it is not an encrypted envelope, the decrypted
            item if it is, or None when the envelope is malformed or its token
            is rejected by the cipher (invalid or older than ``ttl``).
        """
        if not _is_encrypted_envelope(item):
            return cast(TResponseInputItem, item)

        # The envelope check only verifies that the key exists; a stored
        # envelope with a non-string payload is treated as undecryptable
        # instead of raising AttributeError out of get_items/pop_item.
        token_text = item.get("payload")
        if not isinstance(token_text, str):
            return None

        try:
            plaintext = self.cipher.decrypt(token_text.encode("utf-8"), ttl=self.ttl)
        except InvalidToken:
            return None
        return cast(TResponseInputItem, _from_json_bytes(plaintext))

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Fetch items from the underlying session and return those that decrypt.

        ``limit`` is forwarded to the underlying session, so the result may
        contain fewer than ``limit`` items when some are skipped.
        """
        encrypted_items = await self.underlying_session.get_items(limit)
        valid_items: list[TResponseInputItem] = []
        for enc in encrypted_items:
            item = self._unwrap(enc)
            if item is not None:
                valid_items.append(item)
        return valid_items

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Encrypt each item and append the envelopes to the underlying session."""
        wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items]
        await self.underlying_session.add_items(cast(list[TResponseInputItem], wrapped))

    async def pop_item(self) -> TResponseInputItem | None:
        """Pop items from the underlying session until one decrypts.

        Popped items that cannot be decrypted are dropped and the next item is
        tried. Returns None once the underlying session returns None.
        """
        while True:
            enc = await self.underlying_session.pop_item()
            # Compare against None explicitly: a falsy item such as an empty
            # dict is still an item, not the end-of-session signal.
            if enc is None:
                return None
            item = self._unwrap(enc)
            if item is not None:
                return item

    async def clear_session(self) -> None:
        """Clear all items from the underlying session."""
        await self.underlying_session.clear_session()
0 commit comments