Skip to content

Commit 4aed8b2

Browse files
committed
feat: add encrypted session support with TTL expiration
1 parent 50a909a commit 4aed8b2

File tree

3 files changed

+517
-15
lines changed

3 files changed

+517
-15
lines changed
Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,42 @@
1-
2-
"""Session memory backends living in the extensions namespace.
3-
4-
This package contains optional, production-grade session implementations that
5-
introduce extra third-party dependencies (database drivers, ORMs, etc.). They
6-
conform to the :class:`agents.memory.session.Session` protocol so they can be
7-
used as a drop-in replacement for :class:`agents.memory.session.SQLiteSession`.
8-
"""
9-
from __future__ import annotations
10-
11-
from .sqlalchemy_session import SQLAlchemySession # noqa: F401
12-
13-
__all__: list[str] = [
14-
"SQLAlchemySession",
15-
]
1+
"""Session memory backends living in the extensions namespace.
2+
3+
This package contains optional, production-grade session implementations that
4+
introduce extra third-party dependencies (database drivers, ORMs, etc.). They
5+
conform to the :class:`agents.memory.session.Session` protocol so they can be
6+
used as a drop-in replacement for :class:`agents.memory.session.SQLiteSession`.
7+
"""
8+
9+
from __future__ import annotations
10+
11+
from typing import Any
12+
13+
__all__: list[str] = [
14+
"EncryptedSession",
15+
"SQLAlchemySession",
16+
]
17+
18+
19+
def __getattr__(name: str) -> Any:
20+
if name == "EncryptedSession":
21+
try:
22+
from .encrypt_session import EncryptedSession # noqa: F401
23+
24+
return EncryptedSession
25+
except ModuleNotFoundError as e:
26+
raise ImportError(
27+
"EncryptedSession requires the 'cryptography' extra. "
28+
"Install it with: pip install openai-agents[encrypt]"
29+
) from e
30+
31+
if name == "SQLAlchemySession":
32+
try:
33+
from .sqlalchemy_session import SQLAlchemySession # noqa: F401
34+
35+
return SQLAlchemySession
36+
except ModuleNotFoundError as e:
37+
raise ImportError(
38+
"SQLAlchemySession requires the 'sqlalchemy' extra. "
39+
"Install it with: pip install openai-agents[sqlalchemy]"
40+
) from e
41+
42+
raise AttributeError(f"module {__name__} has no attribute {name}")
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
"""Encrypted Session wrapper for secure conversation storage.
2+
3+
This module provides transparent encryption for session storage with automatic
4+
expiration of old data. When TTL expires, expired items are silently skipped.
5+
6+
Usage::
7+
8+
from agents.extensions.memory import EncryptedSession, SQLAlchemySession
9+
10+
# Create underlying session (e.g. SQLAlchemySession)
11+
underlying_session = SQLAlchemySession.from_url(
12+
session_id="user-123",
13+
url="postgresql+asyncpg://app:[email protected]/agents",
14+
create_tables=True,
15+
)
16+
17+
# Wrap with encryption and TTL-based expiration
18+
session = EncryptedSession(
19+
session_id="user-123",
20+
underlying_session=underlying_session,
21+
encryption_key="your-encryption-key",
22+
ttl=600, # 10 minutes
23+
)
24+
25+
await Runner.run(agent, "Hello", session=session)
26+
"""
27+
28+
from __future__ import annotations
29+
30+
import base64
31+
import json
32+
from typing import Any, Literal, TypedDict, TypeGuard, cast
33+
34+
from cryptography.fernet import Fernet, InvalidToken
35+
from cryptography.hazmat.primitives import hashes
36+
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
37+
38+
from ...items import TResponseInputItem
39+
from ...memory.session import SessionABC
40+
41+
42+
class EncryptedEnvelope(TypedDict):
43+
"""TypedDict for encrypted message envelopes stored in the underlying session."""
44+
45+
__enc__: Literal[1]
46+
v: int
47+
kid: str
48+
payload: str
49+
50+
51+
def _ensure_fernet_key_bytes(master_key: str) -> bytes:
52+
"""
53+
Accept either a Fernet key (urlsafe-b64, 32 bytes after decode) or a raw string.
54+
Returns raw bytes suitable for HKDF input.
55+
"""
56+
if not master_key:
57+
raise ValueError("encryption_key not set; required for EncryptedSession.")
58+
try:
59+
key_bytes = base64.urlsafe_b64decode(master_key)
60+
if len(key_bytes) == 32:
61+
return key_bytes
62+
except Exception:
63+
pass
64+
return master_key.encode("utf-8")
65+
66+
67+
def _derive_session_fernet_key(master_key_bytes: bytes, session_id: str) -> Fernet:
68+
hkdf = HKDF(
69+
algorithm=hashes.SHA256(),
70+
length=32,
71+
salt=session_id.encode("utf-8"),
72+
info=b"agents.session-store.hkdf.v1",
73+
)
74+
derived = hkdf.derive(master_key_bytes)
75+
return Fernet(base64.urlsafe_b64encode(derived))
76+
77+
78+
def _to_json_bytes(obj: Any) -> bytes:
79+
return json.dumps(obj, ensure_ascii=False, separators=(",", ":"), default=str).encode("utf-8")
80+
81+
82+
def _from_json_bytes(data: bytes) -> Any:
83+
return json.loads(data.decode("utf-8"))
84+
85+
86+
def _is_encrypted_envelope(item: object) -> TypeGuard[EncryptedEnvelope]:
87+
"""Type guard to check if an item is an encrypted envelope."""
88+
return (
89+
isinstance(item, dict)
90+
and item.get("__enc__") == 1
91+
and "payload" in item
92+
and "kid" in item
93+
and "v" in item
94+
)
95+
96+
97+
class EncryptedSession(SessionABC):
98+
"""Encrypted wrapper for Session implementations with TTL-based expiration.
99+
100+
This class wraps any SessionABC implementation to provide transparent
101+
encryption/decryption of stored items using Fernet encryption with
102+
per-session key derivation and automatic expiration of old data.
103+
104+
When items expire (exceed TTL), they are silently skipped during retrieval.
105+
106+
Note: Expired tokens are rejected based on the system clock of the application server.
107+
To avoid valid tokens being rejected due to clock drift, ensure all servers in
108+
your environment are synchronized using NTP.
109+
"""
110+
111+
def __init__(
112+
self,
113+
session_id: str,
114+
underlying_session: SessionABC,
115+
encryption_key: str,
116+
ttl: int = 600,
117+
):
118+
"""
119+
Args:
120+
session_id: ID for this session
121+
underlying_session: The real session store (e.g. SQLiteSession, SQLAlchemySession)
122+
encryption_key: Master key (Fernet key or raw secret)
123+
ttl: Token time-to-live in seconds (default 10 min)
124+
"""
125+
self.session_id = session_id
126+
self.underlying_session = underlying_session
127+
self.ttl = ttl
128+
129+
master = _ensure_fernet_key_bytes(encryption_key)
130+
self.cipher = _derive_session_fernet_key(master, session_id)
131+
self._kid = "hkdf-v1"
132+
self._ver = 1
133+
134+
def _wrap(self, item: TResponseInputItem) -> EncryptedEnvelope:
135+
if isinstance(item, dict):
136+
payload = item
137+
elif hasattr(item, "model_dump"):
138+
payload = item.model_dump()
139+
elif hasattr(item, "__dict__"):
140+
payload = item.__dict__
141+
else:
142+
payload = dict(item)
143+
144+
token = self.cipher.encrypt(_to_json_bytes(payload)).decode("utf-8")
145+
return {"__enc__": 1, "v": self._ver, "kid": self._kid, "payload": token}
146+
147+
def _unwrap(self, item: TResponseInputItem | EncryptedEnvelope) -> TResponseInputItem | None:
148+
if not _is_encrypted_envelope(item):
149+
return cast(TResponseInputItem, item)
150+
151+
try:
152+
token = item["payload"].encode("utf-8")
153+
plaintext = self.cipher.decrypt(token, ttl=self.ttl)
154+
return cast(TResponseInputItem, _from_json_bytes(plaintext))
155+
except (InvalidToken, KeyError):
156+
return None
157+
158+
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
159+
encrypted_items = await self.underlying_session.get_items(limit)
160+
valid_items: list[TResponseInputItem] = []
161+
for enc in encrypted_items:
162+
item = self._unwrap(enc)
163+
if item is not None:
164+
valid_items.append(item)
165+
return valid_items
166+
167+
async def add_items(self, items: list[TResponseInputItem]) -> None:
168+
wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items]
169+
await self.underlying_session.add_items(cast(list[TResponseInputItem], wrapped))
170+
171+
async def pop_item(self) -> TResponseInputItem | None:
172+
while True:
173+
enc = await self.underlying_session.pop_item()
174+
if not enc:
175+
return None
176+
item = self._unwrap(enc)
177+
if item is not None:
178+
return item
179+
180+
async def clear_session(self) -> None:
181+
await self.underlying_session.clear_session()

0 commit comments

Comments
 (0)