
Commit 48d944a

249 achieve 100pct doctest (#474)
* Add docstring to config.py
* Add docstring to main.py
* Update docstring in schemas.py
* Update docstring in translate.py
* Alembic docstring
* create_jwt_token docstring
* session_registry docstring
* autoflake8 isort
* autoflake8 isort pylint fixes
* doctest for sampling.py

Signed-off-by: Mihai Criveti <[email protected]>
Parent: ef91f2d

1 file changed: mcpgateway/handlers/sampling.py (+291, -3)

@@ -7,6 +7,38 @@
 This module implements the sampling handler for MCP LLM interactions.
 It handles model selection, sampling preferences, and message generation.
+
+Examples:
+    >>> import asyncio
+    >>> from mcpgateway.models import ModelPreferences
+    >>> handler = SamplingHandler()
+    >>> asyncio.run(handler.initialize())
+    >>>
+    >>> # Test model selection
+    >>> prefs = ModelPreferences(
+    ...     cost_priority=0.2,
+    ...     speed_priority=0.3,
+    ...     intelligence_priority=0.5
+    ... )
+    >>> handler._select_model(prefs)
+    'claude-3-haiku'
+    >>>
+    >>> # Test message validation
+    >>> msg = {
+    ...     "role": "user",
+    ...     "content": {"type": "text", "text": "Hello"}
+    ... }
+    >>> handler._validate_message(msg)
+    True
+    >>>
+    >>> # Test mock sampling
+    >>> messages = [msg]
+    >>> response = handler._mock_sample(messages)
+    >>> print(response)
+    You said: Hello
+    Here is my response...
+    >>>
+    >>> asyncio.run(handler.shutdown())
 """

 # Standard
@@ -34,10 +66,27 @@ class SamplingHandler:
     - Message sampling requests
     - Context management
     - Content validation
+
+    Examples:
+        >>> handler = SamplingHandler()
+        >>> handler._supported_models['claude-3-haiku']
+        (0.8, 0.9, 0.7)
+        >>> len(handler._supported_models)
+        4
     """

     def __init__(self):
-        """Initialize sampling handler."""
+        """Initialize sampling handler.
+
+        Examples:
+            >>> handler = SamplingHandler()
+            >>> isinstance(handler._supported_models, dict)
+            True
+            >>> 'claude-3-opus' in handler._supported_models
+            True
+            >>> handler._supported_models['claude-3-sonnet']
+            (0.5, 0.7, 0.9)
+        """
         self._supported_models = {
             # Maps model names to capabilities scores (cost, speed, intelligence)
             "claude-3-haiku": (0.8, 0.9, 0.7),
@@ -47,11 +96,26 @@ def __init__(self):
         }

     async def initialize(self) -> None:
-        """Initialize sampling handler."""
+        """Initialize sampling handler.
+
+        Examples:
+            >>> import asyncio
+            >>> handler = SamplingHandler()
+            >>> asyncio.run(handler.initialize())
+            >>> # Handler is now initialized
+        """
         logger.info("Initializing sampling handler")

     async def shutdown(self) -> None:
-        """Shutdown sampling handler."""
+        """Shutdown sampling handler.
+
+        Examples:
+            >>> import asyncio
+            >>> handler = SamplingHandler()
+            >>> asyncio.run(handler.initialize())
+            >>> asyncio.run(handler.shutdown())
+            >>> # Handler is now shut down
+        """
         logger.info("Shutting down sampling handler")

     async def create_message(self, db: Session, request: Dict[str, Any]) -> CreateMessageResult:
@@ -66,6 +130,64 @@ async def create_message(self, db: Session, request: Dict[str, Any]) -> CreateMessageResult:

         Raises:
             SamplingError: If sampling fails
+
+        Examples:
+            >>> import asyncio
+            >>> from unittest.mock import Mock
+            >>> handler = SamplingHandler()
+            >>> db = Mock()
+            >>>
+            >>> # Test with valid request
+            >>> request = {
+            ...     "messages": [{
+            ...         "role": "user",
+            ...         "content": {"type": "text", "text": "Hello"}
+            ...     }],
+            ...     "maxTokens": 100,
+            ...     "modelPreferences": {
+            ...         "cost_priority": 0.3,
+            ...         "speed_priority": 0.3,
+            ...         "intelligence_priority": 0.4
+            ...     }
+            ... }
+            >>> result = asyncio.run(handler.create_message(db, request))
+            >>> result.role
+            <Role.ASSISTANT: 'assistant'>
+            >>> result.content.type
+            'text'
+            >>> result.stop_reason
+            'maxTokens'
+            >>>
+            >>> # Test with no messages
+            >>> bad_request = {
+            ...     "messages": [],
+            ...     "maxTokens": 100,
+            ...     "modelPreferences": {
+            ...         "cost_priority": 0.3,
+            ...         "speed_priority": 0.3,
+            ...         "intelligence_priority": 0.4
+            ...     }
+            ... }
+            >>> try:
+            ...     asyncio.run(handler.create_message(db, bad_request))
+            ... except SamplingError as e:
+            ...     print(str(e))
+            No messages provided
+            >>>
+            >>> # Test with no max tokens
+            >>> bad_request = {
+            ...     "messages": [{"role": "user", "content": {"type": "text", "text": "Hi"}}],
+            ...     "modelPreferences": {
+            ...         "cost_priority": 0.3,
+            ...         "speed_priority": 0.3,
+            ...         "intelligence_priority": 0.4
+            ...     }
+            ... }
+            >>> try:
+            ...     asyncio.run(handler.create_message(db, bad_request))
+            ... except SamplingError as e:
+            ...     print(str(e))
+            Max tokens not specified
         """
         try:
             # Extract request parameters
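The two SamplingError doctests above exercise the request checks that run before any sampling happens: an empty messages list and a missing maxTokens value. A minimal sketch of that early validation, with illustrative names (check_request and the local SamplingError stand-in are not taken from sampling.py; the real method goes on to model selection and context handling not shown in this hunk):

from typing import Any, Dict

class SamplingError(Exception):
    """Stand-in for mcpgateway's SamplingError, used only for this sketch."""

def check_request(request: Dict[str, Any]) -> None:
    """Reject requests that the doctests above expect create_message to refuse."""
    if not request.get("messages"):
        raise SamplingError("No messages provided")
    if not request.get("maxTokens"):
        raise SamplingError("Max tokens not specified")

try:
    check_request({"messages": [], "maxTokens": 100})
except SamplingError as exc:
    print(exc)  # No messages provided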
@@ -121,6 +243,56 @@ def _select_model(self, preferences: ModelPreferences) -> str:

         Raises:
             SamplingError: If no suitable model found
+
+        Examples:
+            >>> from mcpgateway.models import ModelPreferences, ModelHint
+            >>> handler = SamplingHandler()
+            >>>
+            >>> # Test intelligence priority
+            >>> prefs = ModelPreferences(
+            ...     cost_priority=1.0,
+            ...     speed_priority=0.0,
+            ...     intelligence_priority=1.0
+            ... )
+            >>> handler._select_model(prefs)
+            'claude-3-opus'
+            >>>
+            >>> # Test speed priority
+            >>> prefs = ModelPreferences(
+            ...     cost_priority=0.0,
+            ...     speed_priority=1.0,
+            ...     intelligence_priority=0.0
+            ... )
+            >>> handler._select_model(prefs)
+            'claude-3-haiku'
+            >>>
+            >>> # Test balanced preferences
+            >>> prefs = ModelPreferences(
+            ...     cost_priority=0.33,
+            ...     speed_priority=0.33,
+            ...     intelligence_priority=0.34
+            ... )
+            >>> model = handler._select_model(prefs)
+            >>> model in handler._supported_models
+            True
+            >>>
+            >>> # Test with model hints
+            >>> prefs = ModelPreferences(
+            ...     hints=[ModelHint(name="opus")],
+            ...     cost_priority=0.5,
+            ...     speed_priority=0.3,
+            ...     intelligence_priority=0.2
+            ... )
+            >>> handler._select_model(prefs)
+            'claude-3-opus'
+            >>>
+            >>> # Test empty supported models (should raise error)
+            >>> handler._supported_models = {}
+            >>> try:
+            ...     handler._select_model(prefs)
+            ... except SamplingError as e:
+            ...     print(str(e))
+            No suitable model found
         """
         # Check model hints first
         if preferences.hints:
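The doctests above pin down the selection outcomes, but this hunk does not show how the (cost, speed, intelligence) capability tuples are combined with the preference weights. A minimal sketch of one way such weighted selection could work, assuming a simple weighted sum of the capability scores against the three priorities (score_model and pick_model are illustrative names, not taken from sampling.py, which raises SamplingError rather than ValueError):

from typing import Dict, Tuple

def score_model(caps: Tuple[float, float, float],
                cost_priority: float,
                speed_priority: float,
                intelligence_priority: float) -> float:
    """Combine (cost, speed, intelligence) capability scores with the caller's priorities."""
    cost, speed, intelligence = caps
    return (cost * cost_priority
            + speed * speed_priority
            + intelligence * intelligence_priority)

def pick_model(supported: Dict[str, Tuple[float, float, float]],
               cost_priority: float,
               speed_priority: float,
               intelligence_priority: float) -> str:
    """Return the supported model with the highest weighted score; fail if none exist."""
    if not supported:
        raise ValueError("No suitable model found")  # sampling.py raises SamplingError here
    return max(
        supported,
        key=lambda name: score_model(supported[name], cost_priority,
                                     speed_priority, intelligence_priority),
    )

# Example with the capability scores shown in the diff:
models = {"claude-3-haiku": (0.8, 0.9, 0.7), "claude-3-sonnet": (0.5, 0.7, 0.9)}
print(pick_model(models, cost_priority=0.0, speed_priority=1.0, intelligence_priority=0.0))
# claude-3-haiku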
@@ -159,6 +331,29 @@ async def _add_context(self, _db: Session, messages: List[Dict[str, Any]], _cont

         Returns:
             Messages with added context
+
+        Examples:
+            >>> import asyncio
+            >>> from unittest.mock import Mock
+            >>> handler = SamplingHandler()
+            >>> db = Mock()
+            >>>
+            >>> messages = [
+            ...     {"role": "user", "content": {"type": "text", "text": "Hello"}},
+            ...     {"role": "assistant", "content": {"type": "text", "text": "Hi there!"}}
+            ... ]
+            >>>
+            >>> # Test with 'none' context type
+            >>> result = asyncio.run(handler._add_context(db, messages, "none"))
+            >>> result == messages
+            True
+            >>>
+            >>> # Test with 'all' context type (currently returns same messages)
+            >>> result = asyncio.run(handler._add_context(db, messages, "all"))
+            >>> result == messages
+            True
+            >>> len(result)
+            2
         """
         # TODO: Implement context gathering based on type
         # For now return original messages
@@ -172,6 +367,65 @@ def _validate_message(self, message: Dict[str, Any]) -> bool:

         Returns:
             True if valid
+
+        Examples:
+            >>> handler = SamplingHandler()
+            >>>
+            >>> # Valid text message
+            >>> msg = {"role": "user", "content": {"type": "text", "text": "Hello"}}
+            >>> handler._validate_message(msg)
+            True
+            >>>
+            >>> # Valid assistant message
+            >>> msg = {"role": "assistant", "content": {"type": "text", "text": "Hi!"}}
+            >>> handler._validate_message(msg)
+            True
+            >>>
+            >>> # Valid image message
+            >>> msg = {
+            ...     "role": "user",
+            ...     "content": {
+            ...         "type": "image",
+            ...         "data": "base64data",
+            ...         "mime_type": "image/png"
+            ...     }
+            ... }
+            >>> handler._validate_message(msg)
+            True
+            >>>
+            >>> # Missing role
+            >>> msg = {"content": {"type": "text", "text": "Hello"}}
+            >>> handler._validate_message(msg)
+            False
+            >>>
+            >>> # Invalid role
+            >>> msg = {"role": "system", "content": {"type": "text", "text": "Hello"}}
+            >>> handler._validate_message(msg)
+            False
+            >>>
+            >>> # Missing content
+            >>> msg = {"role": "user"}
+            >>> handler._validate_message(msg)
+            False
+            >>>
+            >>> # Invalid content type
+            >>> msg = {"role": "user", "content": {"type": "audio"}}
+            >>> handler._validate_message(msg)
+            False
+            >>>
+            >>> # Text content not string
+            >>> msg = {"role": "user", "content": {"type": "text", "text": 123}}
+            >>> handler._validate_message(msg)
+            False
+            >>>
+            >>> # Image missing data
+            >>> msg = {"role": "user", "content": {"type": "image", "mime_type": "image/png"}}
+            >>> handler._validate_message(msg)
+            False
+            >>>
+            >>> # Invalid structure
+            >>> handler._validate_message("not a dict")
+            False
         """
         try:
             # Must have role and content
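The expected outputs above fully determine the validation rules: the role must be "user" or "assistant", text content must carry a string "text" field, and image content must carry both "data" and "mime_type". A standalone sketch that satisfies those doctests, written here as a plain function rather than the handler method, purely for illustration (the real method body is not fully shown in this hunk):

from typing import Any

def validate_message(message: Any) -> bool:
    """Return True only for messages matching the rules exercised in the doctests above."""
    try:
        # Must be a dict with a recognised role and a content dict
        if not isinstance(message, dict):
            return False
        role = message.get("role")
        content = message.get("content")
        if role not in ("user", "assistant") or not isinstance(content, dict):
            return False
        if content.get("type") == "text":
            # Text content must be a string
            return isinstance(content.get("text"), str)
        if content.get("type") == "image":
            # Image content needs both the payload and a MIME type
            return bool(content.get("data")) and bool(content.get("mime_type"))
        return False
    except (AttributeError, TypeError):
        return False

print(validate_message({"role": "user", "content": {"type": "text", "text": "Hello"}}))   # True
print(validate_message({"role": "system", "content": {"type": "text", "text": "Hello"}}))  # False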
@@ -205,6 +459,40 @@ def _mock_sample(

         Returns:
             Sampled response text
+
+        Examples:
+            >>> handler = SamplingHandler()
+            >>>
+            >>> # Single user message
+            >>> messages = [{"role": "user", "content": {"type": "text", "text": "Hello world"}}]
+            >>> handler._mock_sample(messages)
+            'You said: Hello world\\nHere is my response...'
+            >>>
+            >>> # Conversation with multiple messages
+            >>> messages = [
+            ...     {"role": "user", "content": {"type": "text", "text": "Hi"}},
+            ...     {"role": "assistant", "content": {"type": "text", "text": "Hello!"}},
+            ...     {"role": "user", "content": {"type": "text", "text": "How are you?"}}
+            ... ]
+            >>> handler._mock_sample(messages)
+            'You said: How are you?\\nHere is my response...'
+            >>>
+            >>> # Image message
+            >>> messages = [{
+            ...     "role": "user",
+            ...     "content": {"type": "image", "data": "base64", "mime_type": "image/png"}
+            ... }]
+            >>> handler._mock_sample(messages)
+            'You said: I see the image you shared.\\nHere is my response...'
+            >>>
+            >>> # No user messages
+            >>> messages = [{"role": "assistant", "content": {"type": "text", "text": "Hi"}}]
+            >>> handler._mock_sample(messages)
+            "I'm not sure what to respond to."
+            >>>
+            >>> # Empty messages
+            >>> handler._mock_sample([])
+            "I'm not sure what to respond to."
         """
         # Extract last user message
         last_msg = None
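The expected strings above make the mock's behaviour easy to reconstruct: it echoes the last user message (or acknowledges an image) and falls back to a fixed reply when no user message exists. A self-contained sketch consistent with those doctests (illustrative only; the actual _mock_sample body continues past the lines shown in this hunk):

from typing import Any, Dict, List

def mock_sample(messages: List[Dict[str, Any]]) -> str:
    """Echo the last user message, mirroring the doctest outputs above."""
    # Extract last user message
    last_msg = None
    for msg in messages:
        if msg.get("role") == "user":
            last_msg = msg
    if last_msg is None:
        return "I'm not sure what to respond to."
    content = last_msg.get("content", {})
    if content.get("type") == "image":
        text = "I see the image you shared."
    else:
        text = content.get("text", "")
    return f"You said: {text}\nHere is my response..."

print(mock_sample([{"role": "user", "content": {"type": "text", "text": "Hello world"}}]))
# You said: Hello world
# Here is my response...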
