Skip to content

Commit 2c4bd22

Browse files
committed
fix: Improve type safety in Redis session tests
Replace type ignores with proper type casting and safe content access: - Use cast() for FakeRedis Redis type compatibility - Replace direct dict access with .get() methods - Add _safe_rpush helper for async/sync operation handling - Maintain mypy effectiveness without defeating type checking
1 parent 73a2683 commit 2c4bd22

File tree

1 file changed

+31
-18
lines changed

1 file changed

+31
-18
lines changed

tests/extensions/memory/test_redis_session.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from typing import cast
4+
35
import pytest
46

57
pytest.importorskip("redis") # Skip tests if Redis is not installed
@@ -15,18 +17,28 @@
1517
# Try to use fakeredis for in-memory testing, fall back to real Redis if not available
1618
try:
1719
import fakeredis.aioredis
20+
from redis.asyncio import Redis
1821

19-
fake_redis = fakeredis.aioredis.FakeRedis()
22+
# Use the actual Redis type annotation, but cast the FakeRedis implementation
23+
fake_redis_instance = fakeredis.aioredis.FakeRedis()
24+
fake_redis: Redis = cast("Redis", fake_redis_instance)
2025
USE_FAKE_REDIS = True
2126
except ImportError:
22-
fake_redis = None
27+
fake_redis = None # type: ignore[assignment]
2328
USE_FAKE_REDIS = False
2429

2530
if not USE_FAKE_REDIS:
2631
# Fallback to real Redis for tests that need it
2732
REDIS_URL = "redis://localhost:6379/15" # Using database 15 for tests
2833

2934

35+
async def _safe_rpush(client: Redis, key: str, value: str) -> None:
36+
"""Safely handle rpush operations that might be sync or async in fakeredis."""
37+
result = client.rpush(key, value)
38+
if hasattr(result, "__await__"):
39+
await result
40+
41+
3042
@pytest.fixture
3143
def agent() -> Agent:
3244
"""Fixture for a basic agent with a fake model."""
@@ -560,18 +572,18 @@ async def test_decode_responses_client_compatibility():
560572
# get_items should work with string responses
561573
retrieved = await session.get_items()
562574
assert len(retrieved) == 2
563-
assert retrieved[0]["content"] == "Hello with decoded responses"
564-
assert retrieved[1]["content"] == "Response with unicode: 🚀"
575+
assert retrieved[0].get("content") == "Hello with decoded responses"
576+
assert retrieved[1].get("content") == "Response with unicode: 🚀"
565577

566578
# pop_item should also work with string responses
567579
popped = await session.pop_item()
568580
assert popped is not None
569-
assert popped["content"] == "Response with unicode: 🚀"
581+
assert popped.get("content") == "Response with unicode: 🚀"
570582

571583
# Verify one item remains
572584
remaining = await session.get_items()
573585
assert len(remaining) == 1
574-
assert remaining[0]["content"] == "Hello with decoded responses"
586+
assert remaining[0].get("content") == "Hello with decoded responses"
575587

576588
finally:
577589
try:
@@ -614,13 +626,13 @@ async def test_real_redis_decode_responses_compatibility():
614626
# Should work even though Redis returns strings instead of bytes
615627
retrieved = await session.get_items()
616628
assert len(retrieved) == 2
617-
assert retrieved[0]["content"] == "Real Redis with decode_responses=True"
618-
assert retrieved[1]["content"] == "Unicode test: 🎯"
629+
assert retrieved[0].get("content") == "Real Redis with decode_responses=True"
630+
assert retrieved[1].get("content") == "Unicode test: 🎯"
619631

620632
# pop_item should also work
621633
popped = await session.pop_item()
622634
assert popped is not None
623-
assert popped["content"] == "Unicode test: 🎯"
635+
assert popped.get("content") == "Unicode test: 🎯"
624636

625637
finally:
626638
try:
@@ -683,37 +695,37 @@ async def test_corrupted_data_handling():
683695
# Inject corrupted data directly into Redis
684696
messages_key = "test:corruption_test:messages"
685697

686-
# Add invalid JSON
687-
await fake_redis.rpush(messages_key, "invalid json data")
688-
await fake_redis.rpush(messages_key, "{incomplete json")
698+
# Add invalid JSON directly using the typed Redis client
699+
await _safe_rpush(fake_redis, messages_key, "invalid json data")
700+
await _safe_rpush(fake_redis, messages_key, "{incomplete json")
689701

690702
# get_items should skip corrupted data and return valid items
691703
items = await session.get_items()
692704
assert len(items) == 1 # Only the original valid item
693705

694706
# Now add a properly formatted valid item using the session's serialization
695-
valid_item = {"role": "user", "content": "valid after corruption"}
707+
valid_item: TResponseInputItem = {"role": "user", "content": "valid after corruption"}
696708
await session.add_items([valid_item])
697709

698710
# Should now have 2 valid items (corrupted ones skipped)
699711
items = await session.get_items()
700712
assert len(items) == 2
701-
assert items[0]["content"] == "valid message"
702-
assert items[1]["content"] == "valid after corruption"
713+
assert items[0].get("content") == "valid message"
714+
assert items[1].get("content") == "valid after corruption"
703715

704716
# Test pop_item with corrupted data at the end
705-
await fake_redis.rpush(messages_key, "corrupted at end")
717+
await _safe_rpush(fake_redis, messages_key, "corrupted at end")
706718

707719
# The corrupted item should be handled gracefully
708720
# Since it's at the end, pop_item will encounter it first and return None
709721
# But first, let's pop the valid items to get to the corrupted one
710722
popped1 = await session.pop_item()
711723
assert popped1 is not None
712-
assert popped1["content"] == "valid after corruption"
724+
assert popped1.get("content") == "valid after corruption"
713725

714726
popped2 = await session.pop_item()
715727
assert popped2 is not None
716-
assert popped2["content"] == "valid message"
728+
assert popped2.get("content") == "valid message"
717729

718730
# Now we should hit the corrupted data - this should gracefully handle it
719731
# by returning None (and removing the corrupted item)
@@ -755,6 +767,7 @@ async def test_close_method_coverage():
755767

756768
# Test 1: External client (should NOT be closed)
757769
external_client = fake_redis
770+
assert external_client is not None # Type assertion for mypy
758771
session1 = RedisSession(
759772
session_id="close_test_1",
760773
redis_client=external_client,

0 commit comments

Comments
 (0)