Skip to content

Commit 3d07d46

Browse files
committed
fix: enhance RedisStreamInjector to support custom Redis client and improve error handling
1 parent 6e0c324 commit 3d07d46

File tree

3 files changed

+54
-45
lines changed

3 files changed

+54
-45
lines changed

src/a2a/server/events/redis_queue_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(
3636
async def add(self, task_id: str, queue: EventQueue) -> None:
3737
"""Add is not supported in distributed Redis setup.
3838
39-
In a distributed environment, we can't reliably add pre-existing queue
39+
In a distributed environment, we can't reliably add preexisting queue
4040
instances. Use create_or_tap() instead to create Redis-backed queues.
4141
"""
4242
raise NotImplementedError(

src/a2a/utils/stream_write/redis_stream_writer.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,35 @@
2626
class RedisStreamInjector:
2727
"""Professional stream injector for A2A framework."""
2828

29-
def __init__(self, redis_url: str = 'redis://localhost:6379/0'):
29+
def __init__(
30+
self,
31+
redis_url: str = 'redis://localhost:6379/0',
32+
redis_client: Any | None = None,
33+
):
3034
"""Initialize the stream injector."""
31-
if Redis is None:
35+
# Allow passing a custom redis client (e.g. a fake in tests).
36+
if Redis is None and redis_client is None:
3237
raise ImportError(
3338
'redis package is required. Install with: pip install redis'
3439
)
3540

3641
self.redis_url = redis_url
37-
self._client = None
38-
self._connected = False
42+
self._client = redis_client
43+
self._connected = redis_client is not None
3944

4045
async def connect(self) -> None:
4146
"""Establish Redis connection."""
4247
if self._connected:
4348
return
4449

4550
try:
46-
self._client = Redis.from_url(self.redis_url)
47-
await self._client.ping()
51+
if self._client is None:
52+
if Redis is None:
53+
raise ImportError(
54+
'redis package is required. Install with: pip install redis'
55+
)
56+
self._client = Redis.from_url(self.redis_url)
57+
await self._client.ping()
4858
self._connected = True
4959
logger.info('Connected to Redis')
5060
except Exception:
@@ -102,7 +112,7 @@ async def _append_to_stream(
102112
raise RuntimeError('Not connected to Redis. Call connect() first.')
103113

104114
stream_key = self._get_stream_key(task_id)
105-
return await self._client.xadd(stream_key, event_data)
115+
return await self._client.xadd(stream_key, event_data) # type: ignore
106116

107117
async def stream_message(
108118
self, context_id: str, task_id: str, message: dict[str, Any] | Message

tests/utils/test_redis_stream_writer.py

Lines changed: 36 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
1-
import json
21
import pytest
32
from unittest.mock import AsyncMock, patch
43

54
from a2a.types import TaskStatusUpdateEvent
65
from a2a.utils.stream_write.redis_stream_writer import RedisStreamInjector
76

87

8+
@pytest.fixture
9+
def mock_redis_client():
10+
"""Fixture providing a mock Redis client."""
11+
client = AsyncMock()
12+
client.xadd = AsyncMock(return_value='123-0')
13+
client.ping = AsyncMock()
14+
client.aclose = AsyncMock()
15+
return client
16+
17+
918
class TestRedisStreamInjector:
1019
"""Test suite for RedisStreamInjector."""
1120

@@ -90,8 +99,7 @@ async def test_disconnect(self):
9099
mock_client = AsyncMock()
91100
mock_client.aclose = AsyncMock()
92101

93-
injector = RedisStreamInjector()
94-
injector._client = mock_client
102+
injector = RedisStreamInjector(redis_client=mock_client)
95103
injector._connected = True
96104

97105
await injector.disconnect()
@@ -103,15 +111,18 @@ async def test_disconnect(self):
103111
@pytest.mark.asyncio
104112
async def test_disconnect_not_connected(self):
105113
"""Test disconnect when not connected."""
106-
injector = RedisStreamInjector()
107-
injector._client = None
114+
mock_client = AsyncMock()
115+
injector = RedisStreamInjector(redis_client=mock_client)
108116
injector._connected = False
109117

110118
await injector.disconnect()
111119

112-
# Should not raise any errors
120+
# Should not call aclose since not connected
121+
mock_client.aclose.assert_not_called()
122+
123+
# Should not raise any errors and client should remain
113124
assert not injector._connected
114-
assert injector._client is None
125+
assert injector._client == mock_client
115126

116127
@pytest.mark.asyncio
117128
async def test_context_manager(self):
@@ -120,57 +131,46 @@ async def test_context_manager(self):
120131
mock_client.ping = AsyncMock()
121132
mock_client.aclose = AsyncMock()
122133

123-
with patch(
124-
'a2a.utils.stream_write.redis_stream_writer.Redis'
125-
) as mock_redis_class:
126-
mock_redis_class.from_url.return_value = mock_client
127-
128-
injector = RedisStreamInjector()
134+
injector = RedisStreamInjector(redis_client=mock_client)
129135

130-
async with injector as ctx_injector:
131-
assert ctx_injector == injector
132-
assert injector._connected
136+
async with injector as ctx_injector:
137+
assert ctx_injector == injector
138+
assert injector._connected
133139

134-
assert not injector._connected
135-
mock_client.aclose.assert_called_once()
140+
assert not injector._connected
141+
mock_client.aclose.assert_called_once()
136142

137143
def test_get_stream_key(self):
138144
"""Test stream key generation."""
139-
injector = RedisStreamInjector()
145+
mock_client = AsyncMock()
146+
injector = RedisStreamInjector(redis_client=mock_client)
140147

141148
key = injector._get_stream_key('test_task')
142149
assert key == 'a2a:task:test_task'
143150

144151
def test_get_stream_key_empty_task_id(self):
145152
"""Test stream key generation with empty task_id."""
146-
injector = RedisStreamInjector()
153+
mock_client = AsyncMock()
154+
injector = RedisStreamInjector(redis_client=mock_client)
147155

148156
with pytest.raises(ValueError, match='task_id cannot be empty'):
149157
injector._get_stream_key('')
150158

151159
def test_serialize_event(self):
152160
"""Test event serialization."""
153-
injector = RedisStreamInjector()
154-
155-
data = {'key': 'value', 'number': 42}
156-
result = injector._serialize_event('TestEvent', data)
161+
injector = RedisStreamInjector(redis_client=AsyncMock())
157162

158-
assert result['type'] == 'TestEvent'
159-
assert 'payload' in result
160-
161-
# Parse the payload to verify it's correct JSON
162-
payload = json.loads(result['payload'])
163-
assert payload == data
163+
event_data = injector._serialize_event('test_type', {'key': 'value'})
164+
assert event_data['type'] == 'test_type'
165+
assert 'payload' in event_data
164166

165167
@pytest.mark.asyncio
166168
async def test_append_to_stream(self):
167169
"""Test appending event to stream."""
168170
mock_client = AsyncMock()
169171
mock_client.xadd = AsyncMock(return_value='123-0')
170172

171-
injector = RedisStreamInjector()
172-
injector._client = mock_client
173-
injector._connected = True
173+
injector = RedisStreamInjector(redis_client=mock_client)
174174

175175
event_data = {'type': 'Test', 'payload': '{"data": "test"}'}
176176
result = await injector._append_to_stream('test_task', event_data)
@@ -183,7 +183,8 @@ async def test_append_to_stream(self):
183183
@pytest.mark.asyncio
184184
async def test_append_to_stream_not_connected(self):
185185
"""Test append_to_stream when not connected."""
186-
injector = RedisStreamInjector()
186+
mock_client = AsyncMock()
187+
injector = RedisStreamInjector(redis_client=mock_client)
187188
injector._connected = False
188189

189190
with pytest.raises(RuntimeError, match='Not connected to Redis'):
@@ -195,9 +196,7 @@ async def test_stream_message_with_dict(self):
195196
mock_client = AsyncMock()
196197
mock_client.xadd = AsyncMock(return_value='123-0')
197198

198-
injector = RedisStreamInjector()
199-
injector._client = mock_client
200-
injector._connected = True
199+
injector = RedisStreamInjector(redis_client=mock_client)
201200

202201
message_data = {'content': 'test message', 'role': 'assistant'}
203202
result = await injector.stream_message(

0 commit comments

Comments
 (0)