|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +from typing import cast |
| 4 | + |
3 | 5 | import pytest
|
4 | 6 |
|
5 | 7 | pytest.importorskip("redis") # Skip tests if Redis is not installed
|
|
15 | 17 | # Try to use fakeredis for in-memory testing, fall back to real Redis if not available
|
16 | 18 | try:
|
17 | 19 | import fakeredis.aioredis
|
| 20 | + from redis.asyncio import Redis |
18 | 21 |
|
19 |
| - fake_redis = fakeredis.aioredis.FakeRedis() |
| 22 | + # Use the actual Redis type annotation, but cast the FakeRedis implementation |
| 23 | + fake_redis_instance = fakeredis.aioredis.FakeRedis() |
| 24 | + fake_redis: Redis = cast("Redis", fake_redis_instance) |
20 | 25 | USE_FAKE_REDIS = True
|
21 | 26 | except ImportError:
|
22 |
| - fake_redis = None |
| 27 | + fake_redis = None # type: ignore[assignment] |
23 | 28 | USE_FAKE_REDIS = False
|
24 | 29 |
|
25 | 30 | if not USE_FAKE_REDIS:
|
26 | 31 | # Fallback to real Redis for tests that need it
|
27 | 32 | REDIS_URL = "redis://localhost:6379/15" # Using database 15 for tests
|
28 | 33 |
|
29 | 34 |
|
| 35 | +async def _safe_rpush(client: Redis, key: str, value: str) -> None: |
| 36 | + """Safely handle rpush operations that might be sync or async in fakeredis.""" |
| 37 | + result = client.rpush(key, value) |
| 38 | + if hasattr(result, "__await__"): |
| 39 | + await result |
| 40 | + |
| 41 | + |
30 | 42 | @pytest.fixture
|
31 | 43 | def agent() -> Agent:
|
32 | 44 | """Fixture for a basic agent with a fake model."""
|
@@ -560,18 +572,18 @@ async def test_decode_responses_client_compatibility():
|
560 | 572 | # get_items should work with string responses
|
561 | 573 | retrieved = await session.get_items()
|
562 | 574 | assert len(retrieved) == 2
|
563 |
| - assert retrieved[0]["content"] == "Hello with decoded responses" |
564 |
| - assert retrieved[1]["content"] == "Response with unicode: 🚀" |
| 575 | + assert retrieved[0].get("content") == "Hello with decoded responses" |
| 576 | + assert retrieved[1].get("content") == "Response with unicode: 🚀" |
565 | 577 |
|
566 | 578 | # pop_item should also work with string responses
|
567 | 579 | popped = await session.pop_item()
|
568 | 580 | assert popped is not None
|
569 |
| - assert popped["content"] == "Response with unicode: 🚀" |
| 581 | + assert popped.get("content") == "Response with unicode: 🚀" |
570 | 582 |
|
571 | 583 | # Verify one item remains
|
572 | 584 | remaining = await session.get_items()
|
573 | 585 | assert len(remaining) == 1
|
574 |
| - assert remaining[0]["content"] == "Hello with decoded responses" |
| 586 | + assert remaining[0].get("content") == "Hello with decoded responses" |
575 | 587 |
|
576 | 588 | finally:
|
577 | 589 | try:
|
@@ -614,13 +626,13 @@ async def test_real_redis_decode_responses_compatibility():
|
614 | 626 | # Should work even though Redis returns strings instead of bytes
|
615 | 627 | retrieved = await session.get_items()
|
616 | 628 | assert len(retrieved) == 2
|
617 |
| - assert retrieved[0]["content"] == "Real Redis with decode_responses=True" |
618 |
| - assert retrieved[1]["content"] == "Unicode test: 🎯" |
| 629 | + assert retrieved[0].get("content") == "Real Redis with decode_responses=True" |
| 630 | + assert retrieved[1].get("content") == "Unicode test: 🎯" |
619 | 631 |
|
620 | 632 | # pop_item should also work
|
621 | 633 | popped = await session.pop_item()
|
622 | 634 | assert popped is not None
|
623 |
| - assert popped["content"] == "Unicode test: 🎯" |
| 635 | + assert popped.get("content") == "Unicode test: 🎯" |
624 | 636 |
|
625 | 637 | finally:
|
626 | 638 | try:
|
@@ -683,37 +695,37 @@ async def test_corrupted_data_handling():
|
683 | 695 | # Inject corrupted data directly into Redis
|
684 | 696 | messages_key = "test:corruption_test:messages"
|
685 | 697 |
|
686 |
| - # Add invalid JSON |
687 |
| - await fake_redis.rpush(messages_key, "invalid json data") |
688 |
| - await fake_redis.rpush(messages_key, "{incomplete json") |
| 698 | + # Add invalid JSON directly using the typed Redis client |
| 699 | + await _safe_rpush(fake_redis, messages_key, "invalid json data") |
| 700 | + await _safe_rpush(fake_redis, messages_key, "{incomplete json") |
689 | 701 |
|
690 | 702 | # get_items should skip corrupted data and return valid items
|
691 | 703 | items = await session.get_items()
|
692 | 704 | assert len(items) == 1 # Only the original valid item
|
693 | 705 |
|
694 | 706 | # Now add a properly formatted valid item using the session's serialization
|
695 |
| - valid_item = {"role": "user", "content": "valid after corruption"} |
| 707 | + valid_item: TResponseInputItem = {"role": "user", "content": "valid after corruption"} |
696 | 708 | await session.add_items([valid_item])
|
697 | 709 |
|
698 | 710 | # Should now have 2 valid items (corrupted ones skipped)
|
699 | 711 | items = await session.get_items()
|
700 | 712 | assert len(items) == 2
|
701 |
| - assert items[0]["content"] == "valid message" |
702 |
| - assert items[1]["content"] == "valid after corruption" |
| 713 | + assert items[0].get("content") == "valid message" |
| 714 | + assert items[1].get("content") == "valid after corruption" |
703 | 715 |
|
704 | 716 | # Test pop_item with corrupted data at the end
|
705 |
| - await fake_redis.rpush(messages_key, "corrupted at end") |
| 717 | + await _safe_rpush(fake_redis, messages_key, "corrupted at end") |
706 | 718 |
|
707 | 719 | # The corrupted item should be handled gracefully
|
708 | 720 | # Since it's at the end, pop_item will encounter it first and return None
|
709 | 721 | # But first, let's pop the valid items to get to the corrupted one
|
710 | 722 | popped1 = await session.pop_item()
|
711 | 723 | assert popped1 is not None
|
712 |
| - assert popped1["content"] == "valid after corruption" |
| 724 | + assert popped1.get("content") == "valid after corruption" |
713 | 725 |
|
714 | 726 | popped2 = await session.pop_item()
|
715 | 727 | assert popped2 is not None
|
716 |
| - assert popped2["content"] == "valid message" |
| 728 | + assert popped2.get("content") == "valid message" |
717 | 729 |
|
718 | 730 | # Now we should hit the corrupted data - this should gracefully handle it
|
719 | 731 | # by returning None (and removing the corrupted item)
|
@@ -755,6 +767,7 @@ async def test_close_method_coverage():
|
755 | 767 |
|
756 | 768 | # Test 1: External client (should NOT be closed)
|
757 | 769 | external_client = fake_redis
|
| 770 | + assert external_client is not None # Type assertion for mypy |
758 | 771 | session1 = RedisSession(
|
759 | 772 | session_id="close_test_1",
|
760 | 773 | redis_client=external_client,
|
|
0 commit comments