Skip to content

Commit 63ea9e4

Browse files
committed
fix(tests): corrected extension import logic, added enhancementas as per suggestions
1 parent 56eb2a6 commit 63ea9e4

File tree

5 files changed

+44
-41
lines changed

5 files changed

+44
-41
lines changed

examples/memory/encrypted_session_example.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
"""
88

99
import asyncio
10+
from typing import cast
1011

1112
from agents import Agent, Runner, SQLiteSession
1213
from agents.extensions.memory import EncryptedSession
14+
from agents.extensions.memory.encrypt_session import EncryptedEnvelope
1315

1416

1517
async def main():
@@ -27,7 +29,7 @@ async def main():
2729
session = EncryptedSession(
2830
session_id=session_id,
2931
underlying_session=underlying_session,
30-
encryption_key="my-secret-encryption-key", # In production, use ENCRYPTION_KEY env var
32+
encryption_key="my-secret-encryption-key",
3133
ttl=3600, # 1 hour TTL for messages
3234
)
3335

@@ -89,8 +91,10 @@ async def main():
8991
print("Raw encrypted items in underlying storage:")
9092
for i, item in enumerate(raw_items, 1):
9193
if isinstance(item, dict) and item.get("__enc__") == 1:
94+
enc_item = cast(EncryptedEnvelope, item)
9295
print(
93-
f" {i}. Encrypted envelope: __enc__={item['__enc__']}, payload length={len(item.get('payload', ''))}"
96+
f" {i}. Encrypted envelope: __enc__={enc_item['__enc__']}, "
97+
f"payload length={len(enc_item['payload'])}"
9498
)
9599
else:
96100
print(f" {i}. Unencrypted item: {item}")

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ dev = [
6666
"eval-type-backport>=0.2.2",
6767
"fastapi >= 0.110.0, <1",
6868
"aiosqlite>=0.21.0",
69+
"types-cryptography",
70+
"cryptography>=45.0, <46",
6971
]
7072

7173
[tool.uv.workspace]

src/agents/extensions/memory/__init__.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,38 @@
88

99
from __future__ import annotations
1010

11+
from typing import Any
12+
1113
from .encrypt_session import EncryptedSession # noqa: F401
1214
from .sqlalchemy_session import SQLAlchemySession # noqa: F401
1315

1416
__all__: list[str] = [
1517
"EncryptedSession",
1618
"SQLAlchemySession",
1719
]
20+
21+
22+
def __getattr__(name: str) -> Any:
23+
if name == "EncryptedSession":
24+
try:
25+
from .encrypt_session import EncryptedSession
26+
27+
return EncryptedSession
28+
except ModuleNotFoundError as e:
29+
raise ImportError(
30+
"EncryptedSession requires the 'cryptography' extra. "
31+
"Install it with: pip install openai-agents[encrypt]"
32+
) from e
33+
34+
if name == "SQLAlchemySession":
35+
try:
36+
from .sqlalchemy_session import SQLAlchemySession
37+
38+
return SQLAlchemySession
39+
except ModuleNotFoundError as e:
40+
raise ImportError(
41+
"SQLAlchemySession requires the 'sqlalchemy' extra. "
42+
"Install it with: pip install openai-agents[sqlalchemy]"
43+
) from e
44+
45+
raise AttributeError(f"module {__name__} has no attribute {name}")

src/agents/extensions/memory/encrypt_session.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
session = EncryptedSession(
1919
session_id="user-123",
2020
underlying_session=underlying_session,
21-
encryption_key="your-encryption-key", # or use ENCRYPTION_KEY env var
21+
encryption_key="your-encryption-key",
2222
ttl=600, # 10 minutes
2323
)
2424
@@ -29,7 +29,6 @@
2929

3030
import base64
3131
import json
32-
import os
3332
from typing import Any, Literal, TypedDict, TypeGuard, cast
3433

3534
from cryptography.fernet import Fernet, InvalidToken
@@ -55,7 +54,7 @@ def _ensure_fernet_key_bytes(master_key: str) -> bytes:
5554
Returns raw bytes suitable for HKDF input.
5655
"""
5756
if not master_key:
58-
raise ValueError("ENCRYPTION_KEY missing; required for EncryptedSession.")
57+
raise ValueError("encryption_key not set; required for EncryptedSession.")
5958
try:
6059
key_bytes = base64.urlsafe_b64decode(master_key)
6160
if len(key_bytes) == 32:
@@ -113,7 +112,7 @@ def __init__(
113112
self,
114113
session_id: str,
115114
underlying_session: SessionABC,
116-
encryption_key: str | None = os.getenv("ENCRYPTION_KEY", None),
115+
encryption_key: str,
117116
ttl: int = 600,
118117
):
119118
"""
@@ -126,8 +125,7 @@ def __init__(
126125
self.session_id = session_id
127126
self.underlying_session = underlying_session
128127
self.ttl = ttl
129-
if encryption_key is None:
130-
raise ValueError("ENCRYPTION_KEY missing; required for EncryptedSession.")
128+
131129
master = _ensure_fernet_key_bytes(encryption_key)
132130
self.cipher = _derive_session_fernet_key(master, session_id)
133131
self._kid = "hkdf-v1"

tests/extensions/memory/test_encrypt_session.py

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from __future__ import annotations
22

3-
import os
43
import tempfile
54
import time
65
from pathlib import Path
7-
from unittest.mock import patch
86

97
import pytest
8+
9+
pytest.importorskip("cryptography") # Skip tests if cryptography is not installed
10+
1011
from cryptography.fernet import Fernet
1112

1213
from agents import Agent, Runner, SQLiteSession, TResponseInputItem
@@ -27,7 +28,7 @@ def agent() -> Agent:
2728
@pytest.fixture
2829
def encryption_key() -> str:
2930
"""Fixture for a valid Fernet encryption key."""
30-
return Fernet.generate_key().decode("utf-8")
31+
return str(Fernet.generate_key().decode("utf-8"))
3132

3233

3334
@pytest.fixture
@@ -229,36 +230,6 @@ async def test_encrypted_session_pop_mixed_expired_valid(
229230
underlying_session.close()
230231

231232

232-
async def test_encrypted_session_env_key(underlying_session: SQLiteSession):
233-
"""Test encryption key from environment variable."""
234-
key = Fernet.generate_key().decode("utf-8")
235-
236-
with patch.dict(os.environ, {"ENCRYPTION_KEY": key}):
237-
session = EncryptedSession(
238-
session_id="test_session",
239-
underlying_session=underlying_session,
240-
)
241-
242-
await session.add_items([{"role": "user", "content": "Test"}])
243-
items = await session.get_items()
244-
assert len(items) == 1
245-
assert items[0].get("content") == "Test"
246-
247-
underlying_session.close()
248-
249-
250-
async def test_encrypted_session_missing_key(underlying_session: SQLiteSession):
251-
"""Test error handling for missing encryption key."""
252-
with pytest.raises(ValueError, match="ENCRYPTION_KEY missing"):
253-
EncryptedSession(
254-
session_id="test_session",
255-
underlying_session=underlying_session,
256-
encryption_key=None,
257-
)
258-
259-
underlying_session.close()
260-
261-
262233
async def test_encrypted_session_raw_string_key(underlying_session: SQLiteSession):
263234
"""Test using raw string as encryption key (not base64)."""
264235
session = EncryptedSession(

0 commit comments

Comments
 (0)