Skip to content

Commit 9ffd085

Browse files
Chore/doctest services (#449)
* Add doctring and doctest to all services in mcpgateway folder Signed-off-by: Manav Gupta <[email protected]> * Restore detailed return value documentation in aggregate_metrics method Signed-off-by: Manav Gupta <[email protected]> --------- Signed-off-by: Manav Gupta <[email protected]> Co-authored-by: Mihai Criveti <[email protected]>
1 parent 25a1e44 commit 9ffd085

File tree

9 files changed

+903
-117
lines changed

9 files changed

+903
-117
lines changed

mcpgateway/services/completion_service.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,19 @@ async def handle_completion(self, db: Session, request: Dict[str, Any]) -> Compl
6363
6464
Raises:
6565
CompletionError: If completion fails
66+
67+
Examples:
68+
>>> from mcpgateway.services.completion_service import CompletionService
69+
>>> from unittest.mock import MagicMock
70+
>>> service = CompletionService()
71+
>>> db = MagicMock()
72+
>>> request = {'ref': {'type': 'ref/prompt', 'name': 'prompt1'}, 'argument': {'name': 'arg1', 'value': ''}}
73+
>>> db.execute.return_value.scalars.return_value.all.return_value = []
74+
>>> import asyncio
75+
>>> try:
76+
... asyncio.run(service.handle_completion(db, request))
77+
... except Exception:
78+
... pass
6679
"""
6780
try:
6881
# Get reference and argument info
@@ -191,6 +204,13 @@ def register_completions(self, arg_name: str, values: List[str]) -> None:
191204
Args:
192205
arg_name: Argument name
193206
values: Completion values
207+
208+
Examples:
209+
>>> from mcpgateway.services.completion_service import CompletionService
210+
>>> service = CompletionService()
211+
>>> service.register_completions('arg1', ['a', 'b'])
212+
>>> service._custom_completions['arg1']
213+
['a', 'b']
194214
"""
195215
self._custom_completions[arg_name] = list(values)
196216

@@ -199,5 +219,13 @@ def unregister_completions(self, arg_name: str) -> None:
199219
200220
Args:
201221
arg_name: Argument name
222+
223+
Examples:
224+
>>> from mcpgateway.services.completion_service import CompletionService
225+
>>> service = CompletionService()
226+
>>> service.register_completions('arg1', ['a', 'b'])
227+
>>> service.unregister_completions('arg1')
228+
>>> 'arg1' in service._custom_completions
229+
False
202230
"""
203231
self._custom_completions.pop(arg_name, None)

mcpgateway/services/gateway_service.py

Lines changed: 132 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,23 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
188188
ValueError: If required values are missing
189189
RuntimeError: If there is an error during processing that is not covered by other exceptions
190190
BaseException: If an unexpected error occurs
191+
192+
Examples:
193+
>>> from mcpgateway.services.gateway_service import GatewayService
194+
>>> from unittest.mock import MagicMock
195+
>>> service = GatewayService()
196+
>>> db = MagicMock()
197+
>>> gateway = MagicMock()
198+
>>> db.execute.return_value.scalar_one_or_none.return_value = None
199+
>>> db.add = MagicMock()
200+
>>> db.commit = MagicMock()
201+
>>> db.refresh = MagicMock()
202+
>>> service._notify_gateway_added = MagicMock()
203+
>>> import asyncio
204+
>>> try:
205+
... asyncio.run(service.register_gateway(db, gateway))
206+
... except Exception:
207+
... pass
191208
"""
192209
try:
193210
# Check for name conflicts (both active and inactive)
@@ -279,6 +296,20 @@ async def list_gateways(self, db: Session, include_inactive: bool = False) -> Li
279296
280297
Returns:
281298
List of registered gateways
299+
300+
Examples:
301+
>>> from mcpgateway.services.gateway_service import GatewayService
302+
>>> from unittest.mock import MagicMock
303+
>>> from mcpgateway.schemas import GatewayRead
304+
>>> service = GatewayService()
305+
>>> db = MagicMock()
306+
>>> gateway_obj = MagicMock()
307+
>>> db.execute.return_value.scalars.return_value.all.return_value = [gateway_obj]
308+
>>> GatewayRead.model_validate = MagicMock(return_value='gateway_read')
309+
>>> import asyncio
310+
>>> result = asyncio.run(service.list_gateways(db))
311+
>>> result == ['gateway_read']
312+
True
282313
"""
283314
query = select(DbGateway)
284315

@@ -391,18 +422,31 @@ async def update_gateway(self, db: Session, gateway_id: str, gateway_update: Gat
391422
raise GatewayError(f"Failed to update gateway: {str(e)}")
392423

393424
async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool = True) -> GatewayRead:
394-
"""Get a specific gateway by ID.
425+
"""
426+
Get a gateway by its ID.
395427
396428
Args:
397429
db: Database session
398430
gateway_id: Gateway ID
399431
include_inactive: Whether to include inactive gateways
400432
401433
Returns:
402-
Gateway information
434+
GatewayRead object
403435
404436
Raises:
405-
GatewayNotFoundError: If gateway not found
437+
GatewayNotFoundError: If the gateway is not found
438+
439+
Examples:
440+
>>> from mcpgateway.services.gateway_service import GatewayService
441+
>>> from unittest.mock import MagicMock
442+
>>> service = GatewayService()
443+
>>> db = MagicMock()
444+
>>> db.get.return_value = MagicMock()
445+
>>> import asyncio
446+
>>> try:
447+
... asyncio.run(service.get_gateway(db, 'gateway_id'))
448+
... except Exception:
449+
... pass
406450
"""
407451
gateway = db.get(DbGateway, gateway_id)
408452
if not gateway:
@@ -414,21 +458,39 @@ async def get_gateway(self, db: Session, gateway_id: str, include_inactive: bool
414458
raise GatewayNotFoundError(f"Gateway not found: {gateway_id}")
415459

416460
async def toggle_gateway_status(self, db: Session, gateway_id: str, activate: bool, reachable: bool = True, only_update_reachable: bool = False) -> GatewayRead:
417-
"""Toggle gateway active status.
461+
"""
462+
Toggle the activation status of a gateway.
418463
419464
Args:
420465
db: Database session
421-
gateway_id: Gateway ID to toggle
466+
gateway_id: Gateway ID
422467
activate: True to activate, False to deactivate
423-
reachable: True if the gateway is reachable, False otherwise
424-
only_update_reachable: If True, only updates reachable status without changing enabled status. Applicable for changing tool status. If the tool is manually deactivated, it will not be reactivated if reachable.
468+
reachable: Whether the gateway is reachable
469+
only_update_reachable: Only update reachable status
425470
426471
Returns:
427-
Updated gateway information
472+
The updated GatewayRead object
428473
429474
Raises:
430-
GatewayNotFoundError: If gateway not found
475+
GatewayNotFoundError: If the gateway is not found
431476
GatewayError: For other errors
477+
478+
Examples:
479+
>>> from mcpgateway.services.gateway_service import GatewayService
480+
>>> from unittest.mock import MagicMock
481+
>>> service = GatewayService()
482+
>>> db = MagicMock()
483+
>>> gateway = MagicMock()
484+
>>> db.get.return_value = gateway
485+
>>> db.commit = MagicMock()
486+
>>> db.refresh = MagicMock()
487+
>>> service._notify_gateway_activated = MagicMock()
488+
>>> service._notify_gateway_deactivated = MagicMock()
489+
>>> import asyncio
490+
>>> try:
491+
... asyncio.run(service.toggle_gateway_status(db, 'gateway_id', True))
492+
... except Exception:
493+
... pass
432494
"""
433495
try:
434496
gateway = db.get(DbGateway, gateway_id)
@@ -523,15 +585,32 @@ async def _notify_gateway_updated(self, gateway: DbGateway) -> None:
523585
await self._publish_event(event)
524586

525587
async def delete_gateway(self, db: Session, gateway_id: str) -> None:
526-
"""Permanently delete a gateway.
588+
"""
589+
Delete a gateway by its ID.
527590
528591
Args:
529592
db: Database session
530-
gateway_id: Gateway ID to delete
593+
gateway_id: Gateway ID
531594
532595
Raises:
533-
GatewayNotFoundError: If gateway not found
596+
GatewayNotFoundError: If the gateway is not found
534597
GatewayError: For other deletion errors
598+
599+
Examples:
600+
>>> from mcpgateway.services.gateway_service import GatewayService
601+
>>> from unittest.mock import MagicMock
602+
>>> service = GatewayService()
603+
>>> db = MagicMock()
604+
>>> gateway = MagicMock()
605+
>>> db.get.return_value = gateway
606+
>>> db.delete = MagicMock()
607+
>>> db.commit = MagicMock()
608+
>>> service._notify_gateway_deleted = MagicMock()
609+
>>> import asyncio
610+
>>> try:
611+
... asyncio.run(service.delete_gateway(db, 'gateway_id'))
612+
... except Exception:
613+
... pass
535614
"""
536615
try:
537616
# Find gateway
@@ -559,7 +638,8 @@ async def delete_gateway(self, db: Session, gateway_id: str) -> None:
559638
raise GatewayError(f"Failed to delete gateway: {str(e)}")
560639

561640
async def forward_request(self, gateway: DbGateway, method: str, params: Optional[Dict[str, Any]] = None) -> Any:
562-
"""Forward a request to a gateway.
641+
"""
642+
Forward a request to a gateway.
563643
564644
Args:
565645
gateway: Gateway to forward to
@@ -572,6 +652,17 @@ async def forward_request(self, gateway: DbGateway, method: str, params: Optiona
572652
Raises:
573653
GatewayConnectionError: If forwarding fails
574654
GatewayError: If gateway gave an error
655+
656+
Examples:
657+
>>> from mcpgateway.services.gateway_service import GatewayService
658+
>>> from unittest.mock import MagicMock
659+
>>> service = GatewayService()
660+
>>> gateway = MagicMock()
661+
>>> import asyncio
662+
>>> try:
663+
... asyncio.run(service.forward_request(gateway, 'method'))
664+
... except Exception:
665+
... pass
575666
"""
576667
if not gateway.enabled:
577668
raise GatewayConnectionError(f"Cannot forward request to inactive gateway: {gateway.name}")
@@ -629,15 +720,24 @@ async def _handle_gateway_failure(self, gateway: str) -> None:
629720
self._gateway_failure_counts[gateway.id] = 0 # Reset after deactivation
630721

631722
async def check_health_of_gateways(self, gateways: List[DbGateway]) -> bool:
632-
"""Health check for a list of gateways.
633-
634-
Deactivates gateway if gateway is not healthy.
723+
"""
724+
Check health of gateways.
635725
636726
Args:
637-
gateways (List[DbGateway]): List of gateways to check if healthy
727+
gateways: List of DbGateway objects
638728
639729
Returns:
640-
bool: True if all active gateways are healthy
730+
True if all gateways are healthy, False otherwise
731+
732+
Examples:
733+
>>> from mcpgateway.services.gateway_service import GatewayService
734+
>>> from unittest.mock import MagicMock
735+
>>> service = GatewayService()
736+
>>> gateways = [MagicMock()]
737+
>>> import asyncio
738+
>>> result = asyncio.run(service.check_health_of_gateways(gateways))
739+
>>> isinstance(result, bool)
740+
True
641741
"""
642742
# Reuse a single HTTP client for all requests
643743
async with httpx.AsyncClient() as client:
@@ -676,13 +776,25 @@ async def check_health_of_gateways(self, gateways: List[DbGateway]) -> bool:
676776
return True
677777

678778
async def aggregate_capabilities(self, db: Session) -> Dict[str, Any]:
679-
"""Aggregate capabilities from all gateways.
779+
"""
780+
Aggregate capabilities across all gateways.
680781
681782
Args:
682783
db: Database session
683784
684785
Returns:
685-
Combined capabilities
786+
Dictionary of aggregated capabilities
787+
788+
Examples:
789+
>>> from mcpgateway.services.gateway_service import GatewayService
790+
>>> from unittest.mock import MagicMock
791+
>>> service = GatewayService()
792+
>>> db = MagicMock()
793+
>>> db.execute.return_value.scalars.return_value.all.return_value = [MagicMock()]
794+
>>> import asyncio
795+
>>> result = asyncio.run(service.aggregate_capabilities(db))
796+
>>> isinstance(result, dict)
797+
True
686798
"""
687799
capabilities = {
688800
"prompts": {"listChanged": True},

mcpgateway/services/logging_service.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,14 @@ def __init__(self):
3636
self._loggers: Dict[str, logging.Logger] = {}
3737

3838
async def initialize(self) -> None:
39-
"""Initialize logging service."""
39+
"""Initialize logging service.
40+
41+
Examples:
42+
>>> from mcpgateway.services.logging_service import LoggingService
43+
>>> import asyncio
44+
>>> service = LoggingService()
45+
>>> asyncio.run(service.initialize())
46+
"""
4047
# Configure root logger
4148
logging.basicConfig(
4249
level=logging.INFO,
@@ -46,7 +53,14 @@ async def initialize(self) -> None:
4653
logging.info("Logging service initialized")
4754

4855
async def shutdown(self) -> None:
49-
"""Shutdown logging service."""
56+
"""Shutdown logging service.
57+
58+
Examples:
59+
>>> from mcpgateway.services.logging_service import LoggingService
60+
>>> import asyncio
61+
>>> service = LoggingService()
62+
>>> asyncio.run(service.shutdown())
63+
"""
5064
# Clear subscribers
5165
self._subscribers.clear()
5266
logging.info("Logging service shutdown")
@@ -59,6 +73,14 @@ def get_logger(self, name: str) -> logging.Logger:
5973
6074
Returns:
6175
Logger instance
76+
77+
Examples:
78+
>>> from mcpgateway.services.logging_service import LoggingService
79+
>>> service = LoggingService()
80+
>>> logger = service.get_logger('test')
81+
>>> import logging
82+
>>> isinstance(logger, logging.Logger)
83+
True
6284
"""
6385
if name not in self._loggers:
6486
logger = logging.getLogger(name)
@@ -78,6 +100,13 @@ async def set_level(self, level: LogLevel) -> None:
78100
79101
Args:
80102
level: New log level
103+
104+
Examples:
105+
>>> from mcpgateway.services.logging_service import LoggingService
106+
>>> from mcpgateway.models import LogLevel
107+
>>> import asyncio
108+
>>> service = LoggingService()
109+
>>> asyncio.run(service.set_level(LogLevel.DEBUG))
81110
"""
82111
self._level = level
83112

@@ -95,6 +124,13 @@ async def notify(self, data: Any, level: LogLevel, logger_name: Optional[str] =
95124
data: Log message data
96125
level: Log severity level
97126
logger_name: Optional logger name
127+
128+
Examples:
129+
>>> from mcpgateway.services.logging_service import LoggingService
130+
>>> from mcpgateway.models import LogLevel
131+
>>> import asyncio
132+
>>> service = LoggingService()
133+
>>> asyncio.run(service.notify('test', LogLevel.INFO))
98134
"""
99135
# Skip if below current level
100136
if not self._should_log(level):
@@ -131,6 +167,9 @@ async def subscribe(self) -> AsyncGenerator[Dict[str, Any], None]:
131167
132168
Yields:
133169
Log message events
170+
171+
Examples:
172+
This example was removed to prevent the test runner from hanging on async generator consumption.
134173
"""
135174
queue: asyncio.Queue = asyncio.Queue()
136175
self._subscribers.append(queue)

0 commit comments

Comments
 (0)