Skip to content

Commit eb88f88

Browse files
committed
Add comprehensive unit tests for JSON argument conversion
- Add TestJSONArgumentConversion class with 9 comprehensive tests - Test basic dict conversion, Optional/Union types, typed dicts - Test error handling for invalid JSON strings - Test async and sync function wrappers - Test complex nested JSON structures - Test annotation modification behavior - Test that non-dict parameters remain unchanged - Add TestJSONSchemaModification class with schema tests - Test tool registration with schema modification - Test multiple dict parameters in single function - Fix union type handling for Python 3.10+ syntax (dict | None) - Support both typing.Union and new UnionType for compatibility - Update schema modification to handle both union syntaxes All 44 tests passing, including 11 new JSON conversion tests
1 parent 9ec361c commit eb88f88

File tree

2 files changed

+263
-5
lines changed

2 files changed

+263
-5
lines changed

jupyter_server_mcp/mcp_server.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,18 @@ def _should_convert_to_dict(annotation, value):
4747
if annotation is dict:
4848
return True
4949

50-
# Optional[dict] or Union[dict, None] etc.
50+
# Optional[dict] or Union[dict, None] etc. (old typing.Union)
5151
origin = get_origin(annotation)
5252
if origin is Union:
5353
args = get_args(annotation)
5454
return dict in args
5555

56+
# New Python 3.10+ union syntax: dict | None
57+
if hasattr(annotation, '__class__') and annotation.__class__.__name__ == 'UnionType':
58+
# For dict | None style unions
59+
args = get_args(annotation)
60+
return dict in args
61+
5662
# Dict[K, V] style annotations
5763
return bool(hasattr(annotation, '__origin__') and annotation.__origin__ is dict)
5864

@@ -62,7 +68,7 @@ def _modify_annotation_for_string_support(annotation):
6268
if annotation is dict:
6369
return dict | str
6470

65-
# Optional[dict] or Union[dict, None] etc.
71+
# Optional[dict] or Union[dict, None] etc. (old typing.Union)
6672
origin = get_origin(annotation)
6773
if origin is Union:
6874
args = get_args(annotation)
@@ -71,6 +77,21 @@ def _modify_annotation_for_string_support(annotation):
7177
if str not in args:
7278
return Union[(*tuple(args), str)]
7379
return annotation
80+
81+
# New Python 3.10+ union syntax: dict | None
82+
if hasattr(annotation, '__class__') and annotation.__class__.__name__ == 'UnionType':
83+
args = get_args(annotation)
84+
if dict in args:
85+
# Add str to the union if it's not already there
86+
if str not in args:
87+
# Reconstruct the union with str added
88+
new_args = (*tuple(args), str)
89+
# Create new union type
90+
result = new_args[0]
91+
for arg in new_args[1:]:
92+
result = result | arg
93+
return result
94+
return annotation
7495

7596
# Dict[K, V] style annotations -> annotation | str
7697
if hasattr(annotation, '__origin__') and annotation.__origin__ is dict:
@@ -175,8 +196,8 @@ def _modify_schema_for_json_string_support(func: Callable, tool) -> None:
175196
# Direct dict annotation
176197
if annotation is dict:
177198
should_support_string = True
178-
# Optional[dict] or Union[dict, None] etc.
179-
elif get_origin(annotation) is Union:
199+
# Optional[dict] or Union[dict, None] etc. (old typing.Union)
200+
elif get_origin(annotation) is Union or (hasattr(annotation, '__class__') and annotation.__class__.__name__ == 'UnionType'):
180201
args = get_args(annotation)
181202
if dict in args:
182203
should_support_string = True

tests/test_mcp_server.py

Lines changed: 238 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pytest
66

7-
from jupyter_server_mcp.mcp_server import MCPServer
7+
from jupyter_server_mcp.mcp_server import MCPServer, _auto_convert_json_args
88

99

1010
def simple_function(x: int, y: int) -> int:
@@ -212,3 +212,240 @@ def test_server_with_multiple_tools(self):
212212
assert server._registered_tools["simple_function"]["is_async"] is False
213213
assert server._registered_tools["async_function"]["is_async"] is True
214214
assert server._registered_tools["printer"]["is_async"] is False
215+
216+
217+
class TestJSONArgumentConversion:
218+
"""Test JSON argument conversion functionality."""
219+
220+
def test_simple_dict_conversion(self):
221+
"""Test basic JSON string to dict conversion."""
222+
223+
def func_with_dict(data: dict) -> dict:
224+
"""Function that expects a dict."""
225+
return {"received": data, "type": type(data).__name__}
226+
227+
wrapped_func = _auto_convert_json_args(func_with_dict)
228+
229+
# Test with actual dict (should pass through)
230+
result = wrapped_func(data={"key": "value"})
231+
assert result["received"] == {"key": "value"}
232+
assert result["type"] == "dict"
233+
234+
# Test with JSON string (should be converted)
235+
result = wrapped_func(data='{"key": "value"}')
236+
assert result["received"] == {"key": "value"}
237+
assert result["type"] == "dict"
238+
239+
def test_optional_dict_conversion(self):
240+
"""Test JSON conversion with Optional[dict] annotation."""
241+
242+
def func_with_optional_dict(data: dict | None = None) -> dict:
243+
"""Function with optional dict parameter."""
244+
return {"received": data, "type": type(data).__name__ if data else "NoneType"}
245+
246+
wrapped_func = _auto_convert_json_args(func_with_optional_dict)
247+
248+
# Test with None (should pass through)
249+
result = wrapped_func(data=None)
250+
assert result["received"] is None
251+
assert result["type"] == "NoneType"
252+
253+
# Test with JSON string (should be converted)
254+
result = wrapped_func(data='{"optional": true}')
255+
assert result["received"] == {"optional": True}
256+
assert result["type"] == "dict"
257+
258+
def test_union_dict_conversion(self):
259+
"""Test JSON conversion with Union type annotations."""
260+
261+
def func_with_union_dict(data: dict | None) -> dict:
262+
"""Function with Union[dict, None] parameter."""
263+
return {"received": data, "type": type(data).__name__ if data else "NoneType"}
264+
265+
wrapped_func = _auto_convert_json_args(func_with_union_dict)
266+
267+
# Test with JSON string (should be converted)
268+
result = wrapped_func(data='{"union": "test"}')
269+
assert result["received"] == {"union": "test"}
270+
assert result["type"] == "dict"
271+
272+
def test_typed_dict_conversion(self):
273+
"""Test JSON conversion with typed dict annotations."""
274+
275+
def func_with_typed_dict(config: dict[str, str]) -> dict:
276+
"""Function with Dict[str, str] annotation."""
277+
return {"received": config, "type": type(config).__name__}
278+
279+
wrapped_func = _auto_convert_json_args(func_with_typed_dict)
280+
281+
# Test with JSON string (should be converted)
282+
result = wrapped_func(config='{"name": "test", "value": "data"}')
283+
assert result["received"] == {"name": "test", "value": "data"}
284+
assert result["type"] == "dict"
285+
286+
def test_invalid_json_handling(self):
287+
"""Test handling of invalid JSON strings."""
288+
289+
def func_with_dict(data: dict) -> dict:
290+
"""Function that expects a dict."""
291+
return {"received": data, "type": type(data).__name__}
292+
293+
wrapped_func = _auto_convert_json_args(func_with_dict)
294+
295+
# Test with invalid JSON (should pass string as-is)
296+
result = wrapped_func(data="invalid json {")
297+
assert result["received"] == "invalid json {"
298+
assert result["type"] == "str"
299+
300+
# Test with empty string (should pass as-is)
301+
result = wrapped_func(data="")
302+
assert result["received"] == ""
303+
assert result["type"] == "str"
304+
305+
def test_non_dict_parameters_unchanged(self):
306+
"""Test that non-dict parameters are not affected."""
307+
308+
def mixed_func(name: str, count: int, data: dict) -> dict:
309+
"""Function with mixed parameter types."""
310+
return {
311+
"name": name,
312+
"name_type": type(name).__name__,
313+
"count": count,
314+
"count_type": type(count).__name__,
315+
"data": data,
316+
"data_type": type(data).__name__
317+
}
318+
319+
wrapped_func = _auto_convert_json_args(mixed_func)
320+
321+
# Only the dict parameter should be converted
322+
result = wrapped_func(
323+
name="test",
324+
count=42,
325+
data='{"converted": true}'
326+
)
327+
328+
assert result["name"] == "test"
329+
assert result["name_type"] == "str"
330+
assert result["count"] == 42
331+
assert result["count_type"] == "int"
332+
assert result["data"] == {"converted": True}
333+
assert result["data_type"] == "dict"
334+
335+
@pytest.mark.asyncio
336+
async def test_async_function_conversion(self):
337+
"""Test JSON conversion with async functions."""
338+
339+
async def async_func_with_dict(config: dict) -> dict:
340+
"""Async function that expects a dict."""
341+
await asyncio.sleep(0.001) # Small delay
342+
return {"async_result": config, "type": type(config).__name__}
343+
344+
wrapped_func = _auto_convert_json_args(async_func_with_dict)
345+
346+
# Test with JSON string (should be converted)
347+
result = await wrapped_func(config='{"async": true, "value": 123}')
348+
assert result["async_result"] == {"async": True, "value": 123}
349+
assert result["type"] == "dict"
350+
351+
def test_complex_nested_json(self):
352+
"""Test conversion of complex nested JSON structures."""
353+
354+
def func_with_nested_dict(data: dict) -> dict:
355+
"""Function that processes nested dict data."""
356+
return {"processed": data}
357+
358+
wrapped_func = _auto_convert_json_args(func_with_nested_dict)
359+
360+
complex_json = '''{
361+
"users": [
362+
{"name": "Alice", "age": 30},
363+
{"name": "Bob", "age": 25}
364+
],
365+
"metadata": {
366+
"version": "1.0",
367+
"created": "2024-01-01"
368+
}
369+
}'''
370+
371+
result = wrapped_func(data=complex_json)
372+
expected = {
373+
"users": [
374+
{"name": "Alice", "age": 30},
375+
{"name": "Bob", "age": 25}
376+
],
377+
"metadata": {
378+
"version": "1.0",
379+
"created": "2024-01-01"
380+
}
381+
}
382+
assert result["processed"] == expected
383+
384+
def test_annotation_modification(self):
385+
"""Test that function annotations are properly modified."""
386+
387+
def original_func(data: dict) -> dict:
388+
"""Original function with dict annotation."""
389+
return data
390+
391+
wrapped_func = _auto_convert_json_args(original_func)
392+
393+
# Check that annotations were modified to accept strings
394+
annotations = wrapped_func.__annotations__
395+
assert 'data' in annotations
396+
397+
# The annotation should now be dict | str (or Union equivalent)
398+
data_annotation = annotations['data']
399+
# We can check this works by ensuring both dict and str are acceptable
400+
assert hasattr(data_annotation, '__args__') or data_annotation == (dict | str)
401+
402+
403+
class TestJSONSchemaModification:
404+
"""Test JSON schema modification for MCP tools."""
405+
406+
def test_schema_modification_applied(self):
407+
"""Test that schema modification is applied during tool registration."""
408+
server = MCPServer()
409+
410+
def func_with_dict_param(config: dict) -> str:
411+
"""Function with dict parameter."""
412+
return f"Received config: {config}"
413+
414+
# Register the function - schema should be automatically modified
415+
server.register_tool(func_with_dict_param)
416+
417+
# Verify the tool was registered
418+
assert "func_with_dict_param" in server._registered_tools
419+
tool_info = server._registered_tools["func_with_dict_param"]
420+
assert tool_info["name"] == "func_with_dict_param"
421+
422+
def test_multiple_dict_parameters(self):
423+
"""Test conversion with multiple dict parameters."""
424+
425+
def func_multiple_dicts(config: dict, metadata: dict, name: str) -> dict:
426+
"""Function with multiple dict parameters."""
427+
return {
428+
"config": config,
429+
"metadata": metadata,
430+
"name": name,
431+
"types": {
432+
"config": type(config).__name__,
433+
"metadata": type(metadata).__name__,
434+
"name": type(name).__name__
435+
}
436+
}
437+
438+
wrapped_func = _auto_convert_json_args(func_multiple_dicts)
439+
440+
result = wrapped_func(
441+
config='{"key1": "value1"}',
442+
metadata='{"version": 2}',
443+
name="test_function"
444+
)
445+
446+
assert result["config"] == {"key1": "value1"}
447+
assert result["metadata"] == {"version": 2}
448+
assert result["name"] == "test_function"
449+
assert result["types"]["config"] == "dict"
450+
assert result["types"]["metadata"] == "dict"
451+
assert result["types"]["name"] == "str"

0 commit comments

Comments
 (0)