Skip to content

Commit e26b1ee

Browse files
committed
feat: enhance Redis session with client ownership and decode_responses compatibility
- Add Redis client ownership tracking with _owns_client attribute - Implement conditional close() method to prevent closing shared clients - Fix decode_responses compatibility in get_items() and pop_item() methods - Handle both bytes (default) and string (decode_responses=True) Redis responses - Add comprehensive test coverage for client ownership scenarios - Add tests for decode_responses=True client compatibility with both fakeredis and real Redis - Ensure proper resource management and thread-safe operations - Maintain backward compatibility while fixing production edge cases This addresses Redis client lifecycle management issues and ensures compatibility with common Redis client configurations used in production.
1 parent 8d5ba83 commit e26b1ee

File tree

2 files changed

+194
-5
lines changed

2 files changed

+194
-5
lines changed

src/agents/extensions/memory/redis_session.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(
6565
self._key_prefix = key_prefix
6666
self._ttl = ttl
6767
self._lock = asyncio.Lock()
68+
self._owns_client = False # Track if we own the Redis client
6869

6970
# Redis key patterns
7071
self._session_key = f"{self._key_prefix}:{self.session_id}"
@@ -101,7 +102,9 @@ def from_url(
101102
redis_kwargs.setdefault("ssl", True)
102103

103104
redis_client = redis.from_url(url, **redis_kwargs)
104-
return cls(session_id, redis_client=redis_client, **kwargs)
105+
session = cls(session_id, redis_client=redis_client, **kwargs)
106+
session._owns_client = True # We created the client, so we own it
107+
return session
105108

106109
async def _serialize_item(self, item: TResponseInputItem) -> str:
107110
"""Serialize an item to JSON string. Can be overridden by subclasses."""
@@ -152,7 +155,11 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
152155
items: list[TResponseInputItem] = []
153156
for raw_msg in raw_messages:
154157
try:
155-
msg_str = raw_msg.decode("utf-8")
158+
# Handle both bytes (default) and str (decode_responses=True) Redis clients
159+
if isinstance(raw_msg, bytes):
160+
msg_str = raw_msg.decode("utf-8")
161+
else:
162+
msg_str = raw_msg # Already a string
156163
item = await self._deserialize_item(msg_str)
157164
items.append(item)
158165
except (json.JSONDecodeError, UnicodeDecodeError):
@@ -217,7 +224,11 @@ async def pop_item(self) -> TResponseInputItem | None:
217224
return None
218225

219226
try:
220-
msg_str = raw_msg.decode("utf-8")
227+
# Handle both bytes (default) and str (decode_responses=True) Redis clients
228+
if isinstance(raw_msg, bytes):
229+
msg_str = raw_msg.decode("utf-8")
230+
else:
231+
msg_str = raw_msg # Already a string
221232
return await self._deserialize_item(msg_str)
222233
except (json.JSONDecodeError, UnicodeDecodeError):
223234
# Return None for corrupted messages (already removed)
@@ -234,8 +245,14 @@ async def clear_session(self) -> None:
234245
)
235246

236247
async def close(self) -> None:
237-
"""Close the Redis connection."""
238-
await self._redis.aclose()
248+
"""Close the Redis connection.
249+
250+
Only closes the connection if this session owns the Redis client
251+
(i.e., created via from_url). If the client was injected externally,
252+
the caller is responsible for managing its lifecycle.
253+
"""
254+
if self._owns_client:
255+
await self._redis.aclose()
239256

240257
async def ping(self) -> bool:
241258
"""Test Redis connectivity.

tests/extensions/memory/test_redis_session.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,3 +450,175 @@ async def test_key_prefix_isolation():
450450
pass # Ignore cleanup errors
451451
await session1.close()
452452
await session2.close()
453+
454+
455+
async def test_external_client_not_closed():
456+
"""Test that external Redis clients are not closed when session.close() is called."""
457+
if not USE_FAKE_REDIS:
458+
pytest.skip("This test requires fakeredis for client state verification")
459+
460+
# Create a shared Redis client
461+
shared_client = fake_redis
462+
463+
# Create session with external client
464+
session = RedisSession(
465+
session_id="external_client_test",
466+
redis_client=shared_client,
467+
key_prefix="test:",
468+
)
469+
470+
try:
471+
# Add some data to verify the client is working
472+
await session.add_items([{"role": "user", "content": "test message"}])
473+
items = await session.get_items()
474+
assert len(items) == 1
475+
476+
# Verify client is working before close
477+
assert await shared_client.ping() is True
478+
479+
# Close the session
480+
await session.close()
481+
482+
# Verify the shared client is still usable after session.close()
483+
# This would fail if we incorrectly closed the external client
484+
assert await shared_client.ping() is True
485+
486+
# Should still be able to use the client for other operations
487+
await shared_client.set("test_key", "test_value")
488+
value = await shared_client.get("test_key")
489+
assert value.decode("utf-8") == "test_value"
490+
491+
finally:
492+
# Clean up
493+
try:
494+
await session.clear_session()
495+
except Exception:
496+
pass # Ignore cleanup errors if connection is already closed
497+
498+
499+
async def test_internal_client_ownership():
500+
"""Test that clients created via from_url are properly managed."""
501+
if USE_FAKE_REDIS:
502+
pytest.skip("This test requires real Redis to test from_url behavior")
503+
504+
# Create session using from_url (internal client)
505+
session = RedisSession.from_url("internal_client_test", url="redis://localhost:6379/15")
506+
507+
try:
508+
if not await session.ping():
509+
pytest.skip("Redis server not available")
510+
511+
# Add some data
512+
await session.add_items([{"role": "user", "content": "test message"}])
513+
items = await session.get_items()
514+
assert len(items) == 1
515+
516+
# The session should properly manage its own client
517+
# Note: We can't easily test that the client is actually closed
518+
# without risking breaking the test, but we can verify the
519+
# session was created with internal client ownership
520+
assert hasattr(session, "_owns_client")
521+
assert session._owns_client is True
522+
523+
finally:
524+
# This should properly close the internal client
525+
await session.close()
526+
527+
528+
async def test_decode_responses_client_compatibility():
529+
"""Test that RedisSession works with Redis clients configured with decode_responses=True."""
530+
if not USE_FAKE_REDIS:
531+
pytest.skip("This test requires fakeredis for client configuration testing")
532+
533+
# Create a Redis client with decode_responses=True
534+
import fakeredis.aioredis
535+
536+
decoded_client = fakeredis.aioredis.FakeRedis(decode_responses=True)
537+
538+
# Create session with the decoded client
539+
session = RedisSession(
540+
session_id="decode_test",
541+
redis_client=decoded_client,
542+
key_prefix="test:",
543+
)
544+
545+
try:
546+
# Test that we can add and retrieve items even when Redis returns strings
547+
test_items: list[TResponseInputItem] = [
548+
{"role": "user", "content": "Hello with decoded responses"},
549+
{"role": "assistant", "content": "Response with unicode: 🚀"},
550+
]
551+
552+
await session.add_items(test_items)
553+
554+
# get_items should work with string responses
555+
retrieved = await session.get_items()
556+
assert len(retrieved) == 2
557+
assert retrieved[0]["content"] == "Hello with decoded responses"
558+
assert retrieved[1]["content"] == "Response with unicode: 🚀"
559+
560+
# pop_item should also work with string responses
561+
popped = await session.pop_item()
562+
assert popped is not None
563+
assert popped["content"] == "Response with unicode: 🚀"
564+
565+
# Verify one item remains
566+
remaining = await session.get_items()
567+
assert len(remaining) == 1
568+
assert remaining[0]["content"] == "Hello with decoded responses"
569+
570+
finally:
571+
try:
572+
await session.clear_session()
573+
except Exception:
574+
pass # Ignore cleanup errors
575+
await session.close()
576+
577+
578+
async def test_real_redis_decode_responses_compatibility():
579+
"""Test RedisSession with a real Redis client configured with decode_responses=True."""
580+
if USE_FAKE_REDIS:
581+
pytest.skip("This test requires real Redis to test decode_responses behavior")
582+
583+
import redis.asyncio as redis
584+
585+
# Create a Redis client with decode_responses=True
586+
decoded_client = redis.Redis.from_url("redis://localhost:6379/15", decode_responses=True)
587+
588+
session = RedisSession(
589+
session_id="real_decode_test",
590+
redis_client=decoded_client,
591+
key_prefix="test:",
592+
)
593+
594+
try:
595+
if not await session.ping():
596+
pytest.skip("Redis server not available")
597+
598+
await session.clear_session()
599+
600+
# Test with decode_responses=True client
601+
test_items: list[TResponseInputItem] = [
602+
{"role": "user", "content": "Real Redis with decode_responses=True"},
603+
{"role": "assistant", "content": "Unicode test: 🎯"},
604+
]
605+
606+
await session.add_items(test_items)
607+
608+
# Should work even though Redis returns strings instead of bytes
609+
retrieved = await session.get_items()
610+
assert len(retrieved) == 2
611+
assert retrieved[0]["content"] == "Real Redis with decode_responses=True"
612+
assert retrieved[1]["content"] == "Unicode test: 🎯"
613+
614+
# pop_item should also work
615+
popped = await session.pop_item()
616+
assert popped is not None
617+
assert popped["content"] == "Unicode test: 🎯"
618+
619+
finally:
620+
try:
621+
await session.clear_session()
622+
except Exception:
623+
pass
624+
await session.close()

0 commit comments

Comments
 (0)