Skip to content

Commit bca039a

Browse files
fix: address code review issues in HTTP Stream implementation
- Fix error propagation in write_stream method - Add proper shutdown handling in SSE listener to prevent resource leaks - Implement proper type annotations based on schema types - Add _closing flag to gracefully stop SSE listener - Create basic test to verify transport selection logic Co-authored-by: Mervin Praison <[email protected]>
1 parent 7b76e43 commit bca039a

File tree

3 files changed

+118
-7
lines changed

3 files changed

+118
-7
lines changed

src/praisonai-agents/praisonaiagents/mcp/mcp_http_stream.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,30 @@ def __init__(self, name: str, description: str, session: ClientSession, input_sc
5050
# Create a signature based on input schema
5151
params = []
5252
if input_schema and 'properties' in input_schema:
53-
for param_name in input_schema['properties']:
53+
for param_name, prop_schema in input_schema['properties'].items():
54+
# Determine type annotation based on schema
55+
prop_type = prop_schema.get('type', 'string') if isinstance(prop_schema, dict) else 'string'
56+
if prop_type == 'string':
57+
annotation = str
58+
elif prop_type == 'integer':
59+
annotation = int
60+
elif prop_type == 'number':
61+
annotation = float
62+
elif prop_type == 'boolean':
63+
annotation = bool
64+
elif prop_type == 'array':
65+
annotation = list
66+
elif prop_type == 'object':
67+
annotation = dict
68+
else:
69+
annotation = Any
70+
5471
params.append(
5572
inspect.Parameter(
5673
name=param_name,
5774
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
5875
default=inspect.Parameter.empty if param_name in input_schema.get('required', []) else None,
59-
annotation=str # Default to string
76+
annotation=annotation
6077
)
6178
)
6279

@@ -174,12 +191,16 @@ def __init__(self, base_url: str, session_id: Optional[str] = None, options: Opt
174191
self._sse_task = None
175192
self._message_queue = asyncio.Queue()
176193
self._pending_requests = {}
194+
self._closing = False
177195

178196
async def __aenter__(self):
179197
self._session = aiohttp.ClientSession()
180198
return self
181199

182200
async def __aexit__(self, exc_type, exc_val, exc_tb):
201+
# Set closing flag to stop listener gracefully
202+
self._closing = True
203+
183204
if self._sse_task:
184205
self._sse_task.cancel()
185206
try:
@@ -254,6 +275,10 @@ async def _sse_listener(self):
254275
"""Background task to listen for SSE events."""
255276
while True:
256277
try:
278+
# Check if we should stop
279+
if hasattr(self, '_closing') and self._closing:
280+
break
281+
257282
url = self.base_url
258283
if self.session_id:
259284
# Add session as query parameter for SSE connection
@@ -269,6 +294,10 @@ async def _sse_listener(self):
269294
async with self._session.get(url, headers=headers) as response:
270295
buffer = ""
271296
async for chunk in response.content:
297+
# Check if we should stop
298+
if hasattr(self, '_closing') and self._closing:
299+
break
300+
272301
buffer += chunk.decode('utf-8')
273302

274303
# Process complete SSE events
@@ -289,9 +318,15 @@ async def _sse_listener(self):
289318
except json.JSONDecodeError:
290319
logger.error(f"Failed to parse SSE event: {data}")
291320

321+
except asyncio.CancelledError:
322+
# Proper shutdown
323+
break
292324
except Exception as e:
293-
logger.error(f"SSE listener error: {e}")
294-
await asyncio.sleep(1) # Reconnect after 1 second
325+
if not (hasattr(self, '_closing') and self._closing):
326+
logger.error(f"SSE listener error: {e}")
327+
await asyncio.sleep(1) # Reconnect after 1 second
328+
else:
329+
break
295330

296331
def read_stream(self):
297332
"""Create a read stream for the ClientSession."""
@@ -306,7 +341,8 @@ def write_stream(self):
306341
async def _write(message):
307342
if hasattr(message, 'to_dict'):
308343
message = message.to_dict()
309-
await self.send_request(message)
344+
response = await self.send_request(message)
345+
return response
310346
return _write
311347

312348

src/praisonai-agents/praisonaiagents/mcp/mcp_sse.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,30 @@ def __init__(self, name: str, description: str, session: ClientSession, input_sc
4444
# Create a signature based on input schema
4545
params = []
4646
if input_schema and 'properties' in input_schema:
47-
for param_name in input_schema['properties']:
47+
for param_name, prop_schema in input_schema['properties'].items():
48+
# Determine type annotation based on schema
49+
prop_type = prop_schema.get('type', 'string') if isinstance(prop_schema, dict) else 'string'
50+
if prop_type == 'string':
51+
annotation = str
52+
elif prop_type == 'integer':
53+
annotation = int
54+
elif prop_type == 'number':
55+
annotation = float
56+
elif prop_type == 'boolean':
57+
annotation = bool
58+
elif prop_type == 'array':
59+
annotation = list
60+
elif prop_type == 'object':
61+
annotation = dict
62+
else:
63+
annotation = str # Default to string for SSE
64+
4865
params.append(
4966
inspect.Parameter(
5067
name=param_name,
5168
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
5269
default=inspect.Parameter.empty if param_name in input_schema.get('required', []) else None,
53-
annotation=str # Default to string
70+
annotation=annotation
5471
)
5572
)
5673

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#!/usr/bin/env python
2+
"""
3+
Basic test to verify HTTP Stream implementation is working correctly.
4+
"""
5+
6+
import sys
7+
import os
8+
9+
# Add parent directory to path for imports
10+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
11+
12+
from praisonaiagents.mcp import MCP
13+
14+
def test_transport_selection():
15+
"""Test that URLs are correctly routed to appropriate transports."""
16+
17+
print("Testing transport selection logic...")
18+
19+
# Test 1: SSE URL should use SSE transport
20+
try:
21+
mcp_sse = MCP("http://localhost:8080/sse")
22+
assert mcp_sse.is_sse == True
23+
assert mcp_sse.is_http_stream == False
24+
print("✓ SSE URL correctly uses SSE transport")
25+
except Exception as e:
26+
print(f"✗ SSE URL test failed: {e}")
27+
28+
# Test 2: Regular HTTP URL should use HTTP Stream transport
29+
try:
30+
mcp_http = MCP("http://localhost:8080")
31+
assert mcp_http.is_sse == False
32+
assert mcp_http.is_http_stream == True
33+
print("✓ HTTP URL correctly uses HTTP Stream transport")
34+
except Exception as e:
35+
print(f"✗ HTTP URL test failed: {e}")
36+
37+
# Test 3: Custom endpoint should use HTTP Stream transport
38+
try:
39+
mcp_custom = MCP("http://localhost:8080/custom")
40+
assert mcp_custom.is_sse == False
41+
assert mcp_custom.is_http_stream == True
42+
print("✓ Custom endpoint correctly uses HTTP Stream transport")
43+
except Exception as e:
44+
print(f"✗ Custom endpoint test failed: {e}")
45+
46+
# Test 4: Stdio transport should still work
47+
try:
48+
mcp_stdio = MCP("python /path/to/server.py")
49+
assert mcp_stdio.is_sse == False
50+
assert mcp_stdio.is_http_stream == False
51+
print("✓ Stdio transport still works")
52+
except Exception as e:
53+
print(f"✗ Stdio transport test failed: {e}")
54+
55+
print("\nAll transport selection tests completed!")
56+
57+
if __name__ == "__main__":
58+
test_transport_selection()

0 commit comments

Comments
 (0)