-
Notifications
You must be signed in to change notification settings - Fork 2.7k
Add encryption support using cryptography to Sessions implementation #1674
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+726
−17
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
4aed8b2
feat: add encrypted session support with TTL expiration
maxmekiska 7cf7124
docs: add encrypted session usage example
maxmekiska 293c650
build: add cryptography dependency for encrypted sessions
maxmekiska a4a2244
Merge branch 'main' into encrypt-session
seratch 5083c86
feat: add session delegation and resolve Python 3.9 compatibility
maxmekiska d3a359c
Merge branch 'encrypt-session' of github.com:maxmekiska/openai-agents…
maxmekiska 09e30c1
Merge branch 'main' into encrypt-session
maxmekiska 5adcb18
Merge branch 'main' into encrypt-session
maxmekiska File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
""" | ||
Example demonstrating encrypted session memory functionality. | ||
This example shows how to use encrypted session memory to maintain conversation history | ||
across multiple agent runs with automatic encryption and TTL-based expiration. | ||
The EncryptedSession wrapper provides transparent encryption over any underlying session. | ||
""" | ||
|
||
import asyncio | ||
from typing import cast | ||
|
||
from agents import Agent, Runner, SQLiteSession | ||
from agents.extensions.memory import EncryptedSession | ||
from agents.extensions.memory.encrypt_session import EncryptedEnvelope | ||
|
||
|
||
async def main(): | ||
# Create an agent | ||
agent = Agent( | ||
name="Assistant", | ||
instructions="Reply very concisely.", | ||
) | ||
|
||
# Create an underlying session (SQLiteSession in this example) | ||
session_id = "conversation_123" | ||
underlying_session = SQLiteSession(session_id) | ||
|
||
# Wrap with encrypted session for automatic encryption and TTL | ||
session = EncryptedSession( | ||
session_id=session_id, | ||
underlying_session=underlying_session, | ||
encryption_key="my-secret-encryption-key", | ||
ttl=3600, # 1 hour TTL for messages | ||
) | ||
|
||
print("=== Encrypted Session Example ===") | ||
print("The agent will remember previous messages automatically with encryption.\n") | ||
|
||
# First turn | ||
print("First turn:") | ||
print("User: What city is the Golden Gate Bridge in?") | ||
result = await Runner.run( | ||
agent, | ||
"What city is the Golden Gate Bridge in?", | ||
session=session, | ||
) | ||
print(f"Assistant: {result.final_output}") | ||
print() | ||
|
||
# Second turn - the agent will remember the previous conversation | ||
print("Second turn:") | ||
print("User: What state is it in?") | ||
result = await Runner.run(agent, "What state is it in?", session=session) | ||
print(f"Assistant: {result.final_output}") | ||
print() | ||
|
||
# Third turn - continuing the conversation | ||
print("Third turn:") | ||
print("User: What's the population of that state?") | ||
result = await Runner.run( | ||
agent, | ||
"What's the population of that state?", | ||
session=session, | ||
) | ||
print(f"Assistant: {result.final_output}") | ||
print() | ||
|
||
print("=== Conversation Complete ===") | ||
print("Notice how the agent remembered the context from previous turns!") | ||
print("All conversation history was automatically encrypted and stored securely.") | ||
|
||
# Demonstrate the limit parameter - get only the latest 2 items | ||
print("\n=== Latest Items Demo ===") | ||
latest_items = await session.get_items(limit=2) | ||
print("Latest 2 items (automatically decrypted):") | ||
for i, msg in enumerate(latest_items, 1): | ||
role = msg.get("role", "unknown") | ||
content = msg.get("content", "") | ||
print(f" {i}. {role}: {content}") | ||
|
||
print(f"\nFetched {len(latest_items)} out of total conversation history.") | ||
|
||
# Get all items to show the difference | ||
all_items = await session.get_items() | ||
print(f"Total items in session: {len(all_items)}") | ||
|
||
# Show that underlying storage is encrypted | ||
print("\n=== Encryption Demo ===") | ||
print("Checking underlying storage to verify encryption...") | ||
raw_items = await underlying_session.get_items() | ||
print("Raw encrypted items in underlying storage:") | ||
for i, item in enumerate(raw_items, 1): | ||
if isinstance(item, dict) and item.get("__enc__") == 1: | ||
enc_item = cast(EncryptedEnvelope, item) | ||
print( | ||
f" {i}. Encrypted envelope: __enc__={enc_item['__enc__']}, " | ||
f"payload length={len(enc_item['payload'])}" | ||
) | ||
else: | ||
print(f" {i}. Unencrypted item: {item}") | ||
|
||
print(f"\nAll {len(raw_items)} items are stored encrypted with TTL-based expiration.") | ||
|
||
# Clean up | ||
underlying_session.close() | ||
|
||
|
||
if __name__ == "__main__": | ||
asyncio.run(main()) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,42 @@ | ||
"""Session memory backends living in the extensions namespace. | ||
This package contains optional, production-grade session implementations that | ||
introduce extra third-party dependencies (database drivers, ORMs, etc.). They | ||
conform to the :class:`agents.memory.session.Session` protocol so they can be | ||
used as a drop-in replacement for :class:`agents.memory.session.SQLiteSession`. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
from .sqlalchemy_session import SQLAlchemySession # noqa: F401 | ||
|
||
__all__: list[str] = [ | ||
"SQLAlchemySession", | ||
] | ||
"""Session memory backends living in the extensions namespace. | ||
This package contains optional, production-grade session implementations that | ||
introduce extra third-party dependencies (database drivers, ORMs, etc.). They | ||
conform to the :class:`agents.memory.session.Session` protocol so they can be | ||
used as a drop-in replacement for :class:`agents.memory.session.SQLiteSession`. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
__all__: list[str] = [ | ||
"EncryptedSession", | ||
"SQLAlchemySession", | ||
] | ||
|
||
|
||
def __getattr__(name: str) -> Any: | ||
if name == "EncryptedSession": | ||
try: | ||
from .encrypt_session import EncryptedSession # noqa: F401 | ||
|
||
return EncryptedSession | ||
except ModuleNotFoundError as e: | ||
raise ImportError( | ||
"EncryptedSession requires the 'cryptography' extra. " | ||
"Install it with: pip install openai-agents[encrypt]" | ||
) from e | ||
|
||
if name == "SQLAlchemySession": | ||
try: | ||
from .sqlalchemy_session import SQLAlchemySession # noqa: F401 | ||
|
||
return SQLAlchemySession | ||
except ModuleNotFoundError as e: | ||
raise ImportError( | ||
"SQLAlchemySession requires the 'sqlalchemy' extra. " | ||
"Install it with: pip install openai-agents[sqlalchemy]" | ||
) from e | ||
|
||
raise AttributeError(f"module {__name__} has no attribute {name}") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
"""Encrypted Session wrapper for secure conversation storage. | ||
|
||
This module provides transparent encryption for session storage with automatic | ||
expiration of old data. When TTL expires, expired items are silently skipped. | ||
|
||
Usage:: | ||
|
||
from agents.extensions.memory import EncryptedSession, SQLAlchemySession | ||
|
||
# Create underlying session (e.g. SQLAlchemySession) | ||
underlying_session = SQLAlchemySession.from_url( | ||
session_id="user-123", | ||
url="postgresql+asyncpg://app:[email protected]/agents", | ||
create_tables=True, | ||
) | ||
|
||
# Wrap with encryption and TTL-based expiration | ||
session = EncryptedSession( | ||
session_id="user-123", | ||
underlying_session=underlying_session, | ||
encryption_key="your-encryption-key", | ||
ttl=600, # 10 minutes | ||
) | ||
|
||
await Runner.run(agent, "Hello", session=session) | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
import base64 | ||
import json | ||
from typing import Any, cast | ||
|
||
from cryptography.fernet import Fernet, InvalidToken | ||
from cryptography.hazmat.primitives import hashes | ||
from cryptography.hazmat.primitives.kdf.hkdf import HKDF | ||
from typing_extensions import Literal, TypedDict, TypeGuard | ||
|
||
from ...items import TResponseInputItem | ||
from ...memory.session import SessionABC | ||
|
||
|
||
class EncryptedEnvelope(TypedDict): | ||
"""TypedDict for encrypted message envelopes stored in the underlying session.""" | ||
|
||
__enc__: Literal[1] | ||
v: int | ||
kid: str | ||
payload: str | ||
|
||
|
||
def _ensure_fernet_key_bytes(master_key: str) -> bytes: | ||
""" | ||
Accept either a Fernet key (urlsafe-b64, 32 bytes after decode) or a raw string. | ||
Returns raw bytes suitable for HKDF input. | ||
""" | ||
if not master_key: | ||
raise ValueError("encryption_key not set; required for EncryptedSession.") | ||
try: | ||
key_bytes = base64.urlsafe_b64decode(master_key) | ||
if len(key_bytes) == 32: | ||
return key_bytes | ||
except Exception: | ||
pass | ||
return master_key.encode("utf-8") | ||
|
||
|
||
def _derive_session_fernet_key(master_key_bytes: bytes, session_id: str) -> Fernet: | ||
hkdf = HKDF( | ||
algorithm=hashes.SHA256(), | ||
length=32, | ||
salt=session_id.encode("utf-8"), | ||
info=b"agents.session-store.hkdf.v1", | ||
) | ||
derived = hkdf.derive(master_key_bytes) | ||
return Fernet(base64.urlsafe_b64encode(derived)) | ||
|
||
|
||
def _to_json_bytes(obj: Any) -> bytes: | ||
return json.dumps(obj, ensure_ascii=False, separators=(",", ":"), default=str).encode("utf-8") | ||
|
||
|
||
def _from_json_bytes(data: bytes) -> Any: | ||
return json.loads(data.decode("utf-8")) | ||
|
||
|
||
def _is_encrypted_envelope(item: object) -> TypeGuard[EncryptedEnvelope]: | ||
"""Type guard to check if an item is an encrypted envelope.""" | ||
return ( | ||
isinstance(item, dict) | ||
and item.get("__enc__") == 1 | ||
and "payload" in item | ||
and "kid" in item | ||
and "v" in item | ||
) | ||
|
||
|
||
class EncryptedSession(SessionABC): | ||
"""Encrypted wrapper for Session implementations with TTL-based expiration. | ||
|
||
This class wraps any SessionABC implementation to provide transparent | ||
encryption/decryption of stored items using Fernet encryption with | ||
per-session key derivation and automatic expiration of old data. | ||
|
||
When items expire (exceed TTL), they are silently skipped during retrieval. | ||
|
||
Note: Expired tokens are rejected based on the system clock of the application server. | ||
To avoid valid tokens being rejected due to clock drift, ensure all servers in | ||
your environment are synchronized using NTP. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
session_id: str, | ||
underlying_session: SessionABC, | ||
encryption_key: str, | ||
ttl: int = 600, | ||
): | ||
""" | ||
Args: | ||
session_id: ID for this session | ||
underlying_session: The real session store (e.g. SQLiteSession, SQLAlchemySession) | ||
encryption_key: Master key (Fernet key or raw secret) | ||
ttl: Token time-to-live in seconds (default 10 min) | ||
""" | ||
self.session_id = session_id | ||
self.underlying_session = underlying_session | ||
self.ttl = ttl | ||
|
||
master = _ensure_fernet_key_bytes(encryption_key) | ||
self.cipher = _derive_session_fernet_key(master, session_id) | ||
self._kid = "hkdf-v1" | ||
self._ver = 1 | ||
|
||
def __getattr__(self, name): | ||
return getattr(self.underlying_session, name) | ||
|
||
def _wrap(self, item: TResponseInputItem) -> EncryptedEnvelope: | ||
if isinstance(item, dict): | ||
payload = item | ||
elif hasattr(item, "model_dump"): | ||
payload = item.model_dump() | ||
elif hasattr(item, "__dict__"): | ||
payload = item.__dict__ | ||
else: | ||
payload = dict(item) | ||
|
||
token = self.cipher.encrypt(_to_json_bytes(payload)).decode("utf-8") | ||
return {"__enc__": 1, "v": self._ver, "kid": self._kid, "payload": token} | ||
|
||
def _unwrap(self, item: TResponseInputItem | EncryptedEnvelope) -> TResponseInputItem | None: | ||
if not _is_encrypted_envelope(item): | ||
return cast(TResponseInputItem, item) | ||
|
||
try: | ||
token = item["payload"].encode("utf-8") | ||
plaintext = self.cipher.decrypt(token, ttl=self.ttl) | ||
return cast(TResponseInputItem, _from_json_bytes(plaintext)) | ||
except (InvalidToken, KeyError): | ||
return None | ||
|
||
async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: | ||
encrypted_items = await self.underlying_session.get_items(limit) | ||
valid_items: list[TResponseInputItem] = [] | ||
for enc in encrypted_items: | ||
item = self._unwrap(enc) | ||
if item is not None: | ||
valid_items.append(item) | ||
return valid_items | ||
|
||
async def add_items(self, items: list[TResponseInputItem]) -> None: | ||
wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items] | ||
await self.underlying_session.add_items(cast(list[TResponseInputItem], wrapped)) | ||
|
||
async def pop_item(self) -> TResponseInputItem | None: | ||
while True: | ||
enc = await self.underlying_session.pop_item() | ||
if not enc: | ||
return None | ||
item = self._unwrap(enc) | ||
if item is not None: | ||
return item | ||
|
||
async def clear_session(self) -> None: | ||
await self.underlying_session.clear_session() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.