1- import json
21import pytest
32from unittest .mock import AsyncMock , patch
43
54from a2a .types import TaskStatusUpdateEvent
65from a2a .utils .stream_write .redis_stream_writer import RedisStreamInjector
76
87
8+ @pytest .fixture
9+ def mock_redis_client ():
10+ """Fixture providing a mock Redis client."""
11+ client = AsyncMock ()
12+ client .xadd = AsyncMock (return_value = '123-0' )
13+ client .ping = AsyncMock ()
14+ client .aclose = AsyncMock ()
15+ return client
16+
17+
918class TestRedisStreamInjector :
1019 """Test suite for RedisStreamInjector."""
1120
@@ -90,8 +99,7 @@ async def test_disconnect(self):
9099 mock_client = AsyncMock ()
91100 mock_client .aclose = AsyncMock ()
92101
93- injector = RedisStreamInjector ()
94- injector ._client = mock_client
102+ injector = RedisStreamInjector (redis_client = mock_client )
95103 injector ._connected = True
96104
97105 await injector .disconnect ()
@@ -103,15 +111,18 @@ async def test_disconnect(self):
103111 @pytest .mark .asyncio
104112 async def test_disconnect_not_connected (self ):
105113 """Test disconnect when not connected."""
106- injector = RedisStreamInjector ()
107- injector . _client = None
114+ mock_client = AsyncMock ()
115+ injector = RedisStreamInjector ( redis_client = mock_client )
108116 injector ._connected = False
109117
110118 await injector .disconnect ()
111119
112- # Should not raise any errors
120+ # Should not call aclose since not connected
121+ mock_client .aclose .assert_not_called ()
122+
123+ # Should not raise any errors and client should remain
113124 assert not injector ._connected
114- assert injector ._client is None
125+ assert injector ._client == mock_client
115126
116127 @pytest .mark .asyncio
117128 async def test_context_manager (self ):
@@ -120,57 +131,46 @@ async def test_context_manager(self):
120131 mock_client .ping = AsyncMock ()
121132 mock_client .aclose = AsyncMock ()
122133
123- with patch (
124- 'a2a.utils.stream_write.redis_stream_writer.Redis'
125- ) as mock_redis_class :
126- mock_redis_class .from_url .return_value = mock_client
127-
128- injector = RedisStreamInjector ()
134+ injector = RedisStreamInjector (redis_client = mock_client )
129135
130- async with injector as ctx_injector :
131- assert ctx_injector == injector
132- assert injector ._connected
136+ async with injector as ctx_injector :
137+ assert ctx_injector == injector
138+ assert injector ._connected
133139
134- assert not injector ._connected
135- mock_client .aclose .assert_called_once ()
140+ assert not injector ._connected
141+ mock_client .aclose .assert_called_once ()
136142
137143 def test_get_stream_key (self ):
138144 """Test stream key generation."""
139- injector = RedisStreamInjector ()
145+ mock_client = AsyncMock ()
146+ injector = RedisStreamInjector (redis_client = mock_client )
140147
141148 key = injector ._get_stream_key ('test_task' )
142149 assert key == 'a2a:task:test_task'
143150
144151 def test_get_stream_key_empty_task_id (self ):
145152 """Test stream key generation with empty task_id."""
146- injector = RedisStreamInjector ()
153+ mock_client = AsyncMock ()
154+ injector = RedisStreamInjector (redis_client = mock_client )
147155
148156 with pytest .raises (ValueError , match = 'task_id cannot be empty' ):
149157 injector ._get_stream_key ('' )
150158
151159 def test_serialize_event (self ):
152160 """Test event serialization."""
153- injector = RedisStreamInjector ()
154-
155- data = {'key' : 'value' , 'number' : 42 }
156- result = injector ._serialize_event ('TestEvent' , data )
161+ injector = RedisStreamInjector (redis_client = AsyncMock ())
157162
158- assert result ['type' ] == 'TestEvent'
159- assert 'payload' in result
160-
161- # Parse the payload to verify it's correct JSON
162- payload = json .loads (result ['payload' ])
163- assert payload == data
163+ event_data = injector ._serialize_event ('test_type' , {'key' : 'value' })
164+ assert event_data ['type' ] == 'test_type'
165+ assert 'payload' in event_data
164166
165167 @pytest .mark .asyncio
166168 async def test_append_to_stream (self ):
167169 """Test appending event to stream."""
168170 mock_client = AsyncMock ()
169171 mock_client .xadd = AsyncMock (return_value = '123-0' )
170172
171- injector = RedisStreamInjector ()
172- injector ._client = mock_client
173- injector ._connected = True
173+ injector = RedisStreamInjector (redis_client = mock_client )
174174
175175 event_data = {'type' : 'Test' , 'payload' : '{"data": "test"}' }
176176 result = await injector ._append_to_stream ('test_task' , event_data )
@@ -183,7 +183,8 @@ async def test_append_to_stream(self):
183183 @pytest .mark .asyncio
184184 async def test_append_to_stream_not_connected (self ):
185185 """Test append_to_stream when not connected."""
186- injector = RedisStreamInjector ()
186+ mock_client = AsyncMock ()
187+ injector = RedisStreamInjector (redis_client = mock_client )
187188 injector ._connected = False
188189
189190 with pytest .raises (RuntimeError , match = 'Not connected to Redis' ):
@@ -195,9 +196,7 @@ async def test_stream_message_with_dict(self):
195196 mock_client = AsyncMock ()
196197 mock_client .xadd = AsyncMock (return_value = '123-0' )
197198
198- injector = RedisStreamInjector ()
199- injector ._client = mock_client
200- injector ._connected = True
199+ injector = RedisStreamInjector (redis_client = mock_client )
201200
202201 message_data = {'content' : 'test message' , 'role' : 'assistant' }
203202 result = await injector .stream_message (
0 commit comments