Skip to content

Commit d2ae534

Browse files
authored
Doctest 60 closes #249 (#529)
* Improve doctest coverage Signed-off-by: Mihai Criveti <[email protected]> * Improve doctest coverage 60% Signed-off-by: Mihai Criveti <[email protected]> --------- Signed-off-by: Mihai Criveti <[email protected]>
1 parent 9be54eb commit d2ae534

File tree

7 files changed

+406
-14
lines changed

7 files changed

+406
-14
lines changed

.github/workflows/pytest.yml

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ jobs:
7070
pip install pytest pytest-cov pytest-asyncio coverage[toml]
7171
7272
# -----------------------------------------------------------
73-
# 3️⃣ Run the tests with coverage
73+
# 3️⃣ Run the tests with coverage (fail under 80% coverage)
7474
# -----------------------------------------------------------
7575
- name: 🧪 Run pytest
7676
run: |
@@ -80,14 +80,20 @@ jobs:
8080
--cov-report=html \
8181
--cov-report=term \
8282
--cov-branch \
83-
--cov-fail-under=40
83+
--cov-fail-under=80
8484
8585
# -----------------------------------------------------------
86-
# 4️⃣ Run doctests
86+
# 4️⃣ Run doctests (fail under 55% coverage)
8787
# -----------------------------------------------------------
88-
- name: 🧪 Run doctests
88+
- name: 📊 Doctest coverage with threshold
8989
run: |
90-
pytest --doctest-modules mcpgateway/ --tb=short
90+
# Run doctests with coverage measurement
91+
pytest --doctest-modules mcpgateway/ \
92+
--cov=mcpgateway \
93+
--cov-report=term \
94+
--cov-report=json:doctest-coverage.json \
95+
--cov-fail-under=55 \
96+
--tb=short
9197
9298
# -----------------------------------------------------------
9399
# 5️⃣ Doctest coverage check

mcpgateway/main.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,28 @@ async def dispatch(self, request: Request, call_next):
342342
343343
Returns:
344344
Response: Either the standard route response or a 401/403 error response.
345+
346+
Examples:
347+
>>> import asyncio
348+
>>> from unittest.mock import Mock, AsyncMock, patch
349+
>>> from fastapi import HTTPException
350+
>>> from fastapi.responses import JSONResponse
351+
>>>
352+
>>> # Test unprotected path - should pass through
353+
>>> middleware = DocsAuthMiddleware(None)
354+
>>> request = Mock()
355+
>>> request.url.path = "/api/tools"
356+
>>> request.headers.get.return_value = None
357+
>>> call_next = AsyncMock(return_value="response")
358+
>>>
359+
>>> result = asyncio.run(middleware.dispatch(request, call_next))
360+
>>> result
361+
'response'
362+
>>>
363+
>>> # Test that middleware checks protected paths
364+
>>> request.url.path = "/docs"
365+
>>> isinstance(middleware, DocsAuthMiddleware)
366+
True
345367
"""
346368
protected_paths = ["/docs", "/redoc", "/openapi.json"]
347369

@@ -386,6 +408,36 @@ async def __call__(self, scope, receive, send):
386408
scope (dict): The ASGI connection scope.
387409
receive (Callable): Awaitable that yields events from the client.
388410
send (Callable): Awaitable used to send events to the client.
411+
412+
Examples:
413+
>>> import asyncio
414+
>>> from unittest.mock import AsyncMock, patch
415+
>>>
416+
>>> # Test non-HTTP request passthrough
417+
>>> app_mock = AsyncMock()
418+
>>> middleware = MCPPathRewriteMiddleware(app_mock)
419+
>>> scope = {"type": "websocket", "path": "/ws"}
420+
>>> receive = AsyncMock()
421+
>>> send = AsyncMock()
422+
>>>
423+
>>> asyncio.run(middleware(scope, receive, send))
424+
>>> app_mock.assert_called_once_with(scope, receive, send)
425+
>>>
426+
>>> # Test path rewriting for /servers/123/mcp
427+
>>> app_mock.reset_mock()
428+
>>> scope = {"type": "http", "path": "/servers/123/mcp"}
429+
>>> with patch('mcpgateway.main.streamable_http_auth', return_value=True):
430+
... with patch.object(streamable_http_session, 'handle_streamable_http') as mock_handler:
431+
... asyncio.run(middleware(scope, receive, send))
432+
... scope["path"]
433+
'/mcp'
434+
>>>
435+
>>> # Test regular path (no rewrite)
436+
>>> scope = {"type": "http", "path": "/tools"}
437+
>>> with patch('mcpgateway.main.streamable_http_auth', return_value=True):
438+
... asyncio.run(middleware(scope, receive, send))
439+
... scope["path"]
440+
'/tools'
389441
"""
390442
# Only handle HTTP requests, HTTPS uses scope["type"] == "http" in ASGI
391443
if scope["type"] != "http":

mcpgateway/schemas.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ def to_camel_case(s: str) -> str:
6060
'alreadyCamel'
6161
>>> to_camel_case("")
6262
''
63+
>>> to_camel_case("single")
64+
'single'
65+
>>> to_camel_case("_leading_underscore")
66+
'LeadingUnderscore'
67+
>>> to_camel_case("trailing_underscore_")
68+
'trailingUnderscore'
6369
"""
6470
return "".join(word.capitalize() if i else word for i, word in enumerate(s.split("_")))
6571

@@ -120,6 +126,22 @@ def to_dict(self, use_alias: bool = False) -> Dict[str, Any]:
120126
>>> m = ExampleModel(foo=1, bar='baz')
121127
>>> m.to_dict()
122128
{'foo': 1, 'bar': 'baz'}
129+
130+
>>> # Test with alias
131+
>>> m.to_dict(use_alias=True)
132+
{'foo': 1, 'bar': 'baz'}
133+
134+
>>> # Test with nested model
135+
>>> class NestedModel(BaseModelWithConfigDict):
136+
... nested_field: int
137+
>>> class ParentModel(BaseModelWithConfigDict):
138+
... parent_field: str
139+
... child: NestedModel
140+
>>> nested = NestedModel(nested_field=42)
141+
>>> parent = ParentModel(parent_field="test", child=nested)
142+
>>> result = parent.to_dict()
143+
>>> result['child']
144+
{'nested_field': 42}
123145
"""
124146
output = {}
125147
for key, value in self.model_dump(by_alias=use_alias).items():
@@ -407,6 +429,35 @@ def validate_request_type(cls, v: str, info: ValidationInfo) -> str:
407429
408430
Raises:
409431
ValueError: When value is unsafe
432+
433+
Examples:
434+
>>> # Test MCP integration types
435+
>>> from pydantic import ValidationInfo
436+
>>> info = type('obj', (object,), {'data': {'integration_type': 'MCP'}})
437+
>>> ToolCreate.validate_request_type('SSE', info)
438+
'SSE'
439+
440+
>>> # Test REST integration types
441+
>>> info = type('obj', (object,), {'data': {'integration_type': 'REST'}})
442+
>>> ToolCreate.validate_request_type('GET', info)
443+
'GET'
444+
>>> ToolCreate.validate_request_type('POST', info)
445+
'POST'
446+
447+
>>> # Test invalid REST type
448+
>>> try:
449+
... ToolCreate.validate_request_type('SSE', info)
450+
... except ValueError as e:
451+
... "not allowed for REST" in str(e)
452+
True
453+
454+
>>> # Test invalid MCP type
455+
>>> info = type('obj', (object,), {'data': {'integration_type': 'MCP'}})
456+
>>> try:
457+
... ToolCreate.validate_request_type('GET', info)
458+
... except ValueError as e:
459+
... "not allowed for MCP" in str(e)
460+
True
410461
"""
411462
data = info.data
412463
integration_type = data.get("integration_type")
@@ -434,6 +485,33 @@ def assemble_auth(cls, values: Dict[str, Any]) -> Dict[str, Any]:
434485
435486
Returns:
436487
Dict: Reformatedd values dict
488+
489+
Examples:
490+
>>> # Test basic auth
491+
>>> values = {'auth_type': 'basic', 'auth_username': 'user', 'auth_password': 'pass'}
492+
>>> result = ToolCreate.assemble_auth(values)
493+
>>> 'auth' in result
494+
True
495+
>>> result['auth']['auth_type']
496+
'basic'
497+
498+
>>> # Test bearer auth
499+
>>> values = {'auth_type': 'bearer', 'auth_token': 'mytoken'}
500+
>>> result = ToolCreate.assemble_auth(values)
501+
>>> result['auth']['auth_type']
502+
'bearer'
503+
504+
>>> # Test authheaders
505+
>>> values = {'auth_type': 'authheaders', 'auth_header_key': 'X-API-Key', 'auth_header_value': 'secret'}
506+
>>> result = ToolCreate.assemble_auth(values)
507+
>>> result['auth']['auth_type']
508+
'authheaders'
509+
510+
>>> # Test no auth type
511+
>>> values = {'name': 'test'}
512+
>>> result = ToolCreate.assemble_auth(values)
513+
>>> 'auth' in result
514+
False
437515
"""
438516
logger.debug(
439517
"Assembling auth in ToolCreate with raw values",
@@ -519,6 +597,17 @@ def validate_description(cls, v: Optional[str]) -> Optional[str]:
519597
520598
Raises:
521599
ValueError: When value is unsafe
600+
601+
Examples:
602+
>>> from mcpgateway.schemas import ResourceCreate
603+
>>> ResourceCreate.validate_description('A safe description')
604+
'A safe description'
605+
>>> ResourceCreate.validate_description(None) # Test None case
606+
607+
>>> ResourceCreate.validate_description('x' * 5000)
608+
Traceback (most recent call last):
609+
...
610+
ValueError: ...
522611
"""
523612
if v is None:
524613
return v

mcpgateway/translate.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,20 @@ async def publish(self, data: str) -> None:
137137
... return True
138138
>>> asyncio.run(test_publish())
139139
True
140+
141+
>>> # Test queue full handling
142+
>>> async def test_full_queue():
143+
... pubsub = _PubSub()
144+
... # Create a queue with size 1
145+
... q = asyncio.Queue(maxsize=1)
146+
... pubsub._subscribers = [q]
147+
... # Fill the queue
148+
... await q.put("first")
149+
... # This should remove the full queue
150+
... await pubsub.publish("second")
151+
... return len(pubsub._subscribers)
152+
>>> asyncio.run(test_full_queue())
153+
0
140154
"""
141155
dead: List[asyncio.Queue[str]] = []
142156
for q in self._subscribers:
@@ -405,6 +419,12 @@ def _build_fastapi(
405419
True
406420
>>> "/send" in [r.path for r in app2.routes]
407421
True
422+
423+
>>> # Test CORS middleware is added
424+
>>> app3 = _build_fastapi(pubsub, stdio, cors_origins=["http://example.com"])
425+
>>> # Check that middleware stack includes CORSMiddleware
426+
>>> any("CORSMiddleware" in str(m) for m in app3.user_middleware)
427+
True
408428
"""
409429
app = FastAPI()
410430

mcpgateway/utils/db_isready.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,44 @@ def _parse_cli() -> argparse.Namespace:
320320
321321
Returns:
322322
Parsed :class:`argparse.Namespace` holding all CLI options.
323+
324+
Examples:
325+
>>> import sys
326+
>>> # Save original argv
327+
>>> original_argv = sys.argv
328+
>>>
329+
>>> # Test default values
330+
>>> sys.argv = ['db_isready.py']
331+
>>> args = _parse_cli()
332+
>>> args.database_url == DEFAULT_DB_URL
333+
True
334+
>>> args.max_tries == DEFAULT_MAX_TRIES
335+
True
336+
>>> args.interval == DEFAULT_INTERVAL
337+
True
338+
>>> args.timeout == DEFAULT_TIMEOUT
339+
True
340+
>>> args.log_level == DEFAULT_LOG_LEVEL
341+
True
342+
343+
>>> # Test custom values
344+
>>> sys.argv = ['db_isready.py', '--database-url', 'postgresql://localhost/test',
345+
... '--max-tries', '5', '--interval', '1.5', '--timeout', '10',
346+
... '--log-level', 'DEBUG']
347+
>>> args = _parse_cli()
348+
>>> args.database_url
349+
'postgresql://localhost/test'
350+
>>> args.max_tries
351+
5
352+
>>> args.interval
353+
1.5
354+
>>> args.timeout
355+
10
356+
>>> args.log_level
357+
'DEBUG'
358+
359+
>>> # Restore original argv
360+
>>> sys.argv = original_argv
323361
"""
324362

325363
parser = argparse.ArgumentParser(

0 commit comments

Comments
 (0)