Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions examples/memory/encrypted_session_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""
Example demonstrating encrypted session memory functionality.
This example shows how to use encrypted session memory to maintain conversation history
across multiple agent runs with automatic encryption and TTL-based expiration.
The EncryptedSession wrapper provides transparent encryption over any underlying session.
"""

import asyncio
from typing import cast

from agents import Agent, Runner, SQLiteSession
from agents.extensions.memory import EncryptedSession
from agents.extensions.memory.encrypt_session import EncryptedEnvelope


async def main():
# Create an agent
agent = Agent(
name="Assistant",
instructions="Reply very concisely.",
)

# Create an underlying session (SQLiteSession in this example)
session_id = "conversation_123"
underlying_session = SQLiteSession(session_id)

# Wrap with encrypted session for automatic encryption and TTL
session = EncryptedSession(
session_id=session_id,
underlying_session=underlying_session,
encryption_key="my-secret-encryption-key",
ttl=3600, # 1 hour TTL for messages
)

print("=== Encrypted Session Example ===")
print("The agent will remember previous messages automatically with encryption.\n")

# First turn
print("First turn:")
print("User: What city is the Golden Gate Bridge in?")
result = await Runner.run(
agent,
"What city is the Golden Gate Bridge in?",
session=session,
)
print(f"Assistant: {result.final_output}")
print()

# Second turn - the agent will remember the previous conversation
print("Second turn:")
print("User: What state is it in?")
result = await Runner.run(agent, "What state is it in?", session=session)
print(f"Assistant: {result.final_output}")
print()

# Third turn - continuing the conversation
print("Third turn:")
print("User: What's the population of that state?")
result = await Runner.run(
agent,
"What's the population of that state?",
session=session,
)
print(f"Assistant: {result.final_output}")
print()

print("=== Conversation Complete ===")
print("Notice how the agent remembered the context from previous turns!")
print("All conversation history was automatically encrypted and stored securely.")

# Demonstrate the limit parameter - get only the latest 2 items
print("\n=== Latest Items Demo ===")
latest_items = await session.get_items(limit=2)
print("Latest 2 items (automatically decrypted):")
for i, msg in enumerate(latest_items, 1):
role = msg.get("role", "unknown")
content = msg.get("content", "")
print(f" {i}. {role}: {content}")

print(f"\nFetched {len(latest_items)} out of total conversation history.")

# Get all items to show the difference
all_items = await session.get_items()
print(f"Total items in session: {len(all_items)}")

# Show that underlying storage is encrypted
print("\n=== Encryption Demo ===")
print("Checking underlying storage to verify encryption...")
raw_items = await underlying_session.get_items()
print("Raw encrypted items in underlying storage:")
for i, item in enumerate(raw_items, 1):
if isinstance(item, dict) and item.get("__enc__") == 1:
enc_item = cast(EncryptedEnvelope, item)
print(
f" {i}. Encrypted envelope: __enc__={enc_item['__enc__']}, "
f"payload length={len(enc_item['payload'])}"
)
else:
print(f" {i}. Unencrypted item: {item}")

print(f"\nAll {len(raw_items)} items are stored encrypted with TTL-based expiration.")

# Clean up
underlying_session.close()


if __name__ == "__main__":
asyncio.run(main())
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ viz = ["graphviz>=0.17"]
litellm = ["litellm>=1.67.4.post1, <2"]
realtime = ["websockets>=15.0, <16"]
sqlalchemy = ["SQLAlchemy>=2.0", "asyncpg>=0.29.0"]
encrypt = ["cryptography>=45.0, <46"]

[dependency-groups]
dev = [
Expand All @@ -65,6 +66,7 @@ dev = [
"eval-type-backport>=0.2.2",
"fastapi >= 0.110.0, <1",
"aiosqlite>=0.21.0",
"cryptography>=45.0, <46",
]

[tool.uv.workspace]
Expand Down
57 changes: 42 additions & 15 deletions src/agents/extensions/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,42 @@
"""Session memory backends living in the extensions namespace.
This package contains optional, production-grade session implementations that
introduce extra third-party dependencies (database drivers, ORMs, etc.). They
conform to the :class:`agents.memory.session.Session` protocol so they can be
used as a drop-in replacement for :class:`agents.memory.session.SQLiteSession`.
"""

from __future__ import annotations

from .sqlalchemy_session import SQLAlchemySession # noqa: F401

__all__: list[str] = [
"SQLAlchemySession",
]
"""Session memory backends living in the extensions namespace.
This package contains optional, production-grade session implementations that
introduce extra third-party dependencies (database drivers, ORMs, etc.). They
conform to the :class:`agents.memory.session.Session` protocol so they can be
used as a drop-in replacement for :class:`agents.memory.session.SQLiteSession`.
"""

from __future__ import annotations

from typing import Any

__all__: list[str] = [
"EncryptedSession",
"SQLAlchemySession",
]


def __getattr__(name: str) -> Any:
if name == "EncryptedSession":
try:
from .encrypt_session import EncryptedSession # noqa: F401

return EncryptedSession
except ModuleNotFoundError as e:
raise ImportError(
"EncryptedSession requires the 'cryptography' extra. "
"Install it with: pip install openai-agents[encrypt]"
) from e

if name == "SQLAlchemySession":
try:
from .sqlalchemy_session import SQLAlchemySession # noqa: F401

return SQLAlchemySession
except ModuleNotFoundError as e:
raise ImportError(
"SQLAlchemySession requires the 'sqlalchemy' extra. "
"Install it with: pip install openai-agents[sqlalchemy]"
) from e

raise AttributeError(f"module {__name__} has no attribute {name}")
185 changes: 185 additions & 0 deletions src/agents/extensions/memory/encrypt_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""Encrypted Session wrapper for secure conversation storage.

This module provides transparent encryption for session storage with automatic
expiration of old data. When TTL expires, expired items are silently skipped.

Usage::

from agents.extensions.memory import EncryptedSession, SQLAlchemySession

# Create underlying session (e.g. SQLAlchemySession)
underlying_session = SQLAlchemySession.from_url(
session_id="user-123",
url="postgresql+asyncpg://app:[email protected]/agents",
create_tables=True,
)

# Wrap with encryption and TTL-based expiration
session = EncryptedSession(
session_id="user-123",
underlying_session=underlying_session,
encryption_key="your-encryption-key",
ttl=600, # 10 minutes
)

await Runner.run(agent, "Hello", session=session)
"""

from __future__ import annotations

import base64
import json
from typing import Any, cast

from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from typing_extensions import Literal, TypedDict, TypeGuard

from ...items import TResponseInputItem
from ...memory.session import SessionABC


class EncryptedEnvelope(TypedDict):
"""TypedDict for encrypted message envelopes stored in the underlying session."""

__enc__: Literal[1]
v: int
kid: str
payload: str


def _ensure_fernet_key_bytes(master_key: str) -> bytes:
"""
Accept either a Fernet key (urlsafe-b64, 32 bytes after decode) or a raw string.
Returns raw bytes suitable for HKDF input.
"""
if not master_key:
raise ValueError("encryption_key not set; required for EncryptedSession.")
try:
key_bytes = base64.urlsafe_b64decode(master_key)
if len(key_bytes) == 32:
return key_bytes
except Exception:
pass
return master_key.encode("utf-8")


def _derive_session_fernet_key(master_key_bytes: bytes, session_id: str) -> Fernet:
hkdf = HKDF(
algorithm=hashes.SHA256(),
length=32,
salt=session_id.encode("utf-8"),
info=b"agents.session-store.hkdf.v1",
)
derived = hkdf.derive(master_key_bytes)
return Fernet(base64.urlsafe_b64encode(derived))


def _to_json_bytes(obj: Any) -> bytes:
return json.dumps(obj, ensure_ascii=False, separators=(",", ":"), default=str).encode("utf-8")


def _from_json_bytes(data: bytes) -> Any:
return json.loads(data.decode("utf-8"))


def _is_encrypted_envelope(item: object) -> TypeGuard[EncryptedEnvelope]:
"""Type guard to check if an item is an encrypted envelope."""
return (
isinstance(item, dict)
and item.get("__enc__") == 1
and "payload" in item
and "kid" in item
and "v" in item
)


class EncryptedSession(SessionABC):
"""Encrypted wrapper for Session implementations with TTL-based expiration.

This class wraps any SessionABC implementation to provide transparent
encryption/decryption of stored items using Fernet encryption with
per-session key derivation and automatic expiration of old data.

When items expire (exceed TTL), they are silently skipped during retrieval.

Note: Expired tokens are rejected based on the system clock of the application server.
To avoid valid tokens being rejected due to clock drift, ensure all servers in
your environment are synchronized using NTP.
"""

def __init__(
self,
session_id: str,
underlying_session: SessionABC,
encryption_key: str,
ttl: int = 600,
):
"""
Args:
session_id: ID for this session
underlying_session: The real session store (e.g. SQLiteSession, SQLAlchemySession)
encryption_key: Master key (Fernet key or raw secret)
ttl: Token time-to-live in seconds (default 10 min)
"""
self.session_id = session_id
self.underlying_session = underlying_session
self.ttl = ttl

master = _ensure_fernet_key_bytes(encryption_key)
self.cipher = _derive_session_fernet_key(master, session_id)
self._kid = "hkdf-v1"
self._ver = 1

def __getattr__(self, name):
return getattr(self.underlying_session, name)

def _wrap(self, item: TResponseInputItem) -> EncryptedEnvelope:
if isinstance(item, dict):
payload = item
elif hasattr(item, "model_dump"):
payload = item.model_dump()
elif hasattr(item, "__dict__"):
payload = item.__dict__
else:
payload = dict(item)

token = self.cipher.encrypt(_to_json_bytes(payload)).decode("utf-8")
return {"__enc__": 1, "v": self._ver, "kid": self._kid, "payload": token}

def _unwrap(self, item: TResponseInputItem | EncryptedEnvelope) -> TResponseInputItem | None:
if not _is_encrypted_envelope(item):
return cast(TResponseInputItem, item)

try:
token = item["payload"].encode("utf-8")
plaintext = self.cipher.decrypt(token, ttl=self.ttl)
return cast(TResponseInputItem, _from_json_bytes(plaintext))
except (InvalidToken, KeyError):
return None

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
encrypted_items = await self.underlying_session.get_items(limit)
valid_items: list[TResponseInputItem] = []
for enc in encrypted_items:
item = self._unwrap(enc)
if item is not None:
valid_items.append(item)
return valid_items

async def add_items(self, items: list[TResponseInputItem]) -> None:
wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items]
await self.underlying_session.add_items(cast(list[TResponseInputItem], wrapped))

async def pop_item(self) -> TResponseInputItem | None:
while True:
enc = await self.underlying_session.pop_item()
if not enc:
return None
item = self._unwrap(enc)
if item is not None:
return item

async def clear_session(self) -> None:
await self.underlying_session.clear_session()
Loading