7
7
8
8
This module implements the sampling handler for MCP LLM interactions.
9
9
It handles model selection, sampling preferences, and message generation.
10
+
11
+ Examples:
12
+ >>> import asyncio
13
+ >>> from mcpgateway.models import ModelPreferences
14
+ >>> handler = SamplingHandler()
15
+ >>> asyncio.run(handler.initialize())
16
+ >>>
17
+ >>> # Test model selection
18
+ >>> prefs = ModelPreferences(
19
+ ... cost_priority=0.2,
20
+ ... speed_priority=0.3,
21
+ ... intelligence_priority=0.5
22
+ ... )
23
+ >>> handler._select_model(prefs)
24
+ 'claude-3-haiku'
25
+ >>>
26
+ >>> # Test message validation
27
+ >>> msg = {
28
+ ... "role": "user",
29
+ ... "content": {"type": "text", "text": "Hello"}
30
+ ... }
31
+ >>> handler._validate_message(msg)
32
+ True
33
+ >>>
34
+ >>> # Test mock sampling
35
+ >>> messages = [msg]
36
+ >>> response = handler._mock_sample(messages)
37
+ >>> print(response)
38
+ You said: Hello
39
+ Here is my response...
40
+ >>>
41
+ >>> asyncio.run(handler.shutdown())
10
42
"""
11
43
12
44
# Standard
@@ -34,10 +66,27 @@ class SamplingHandler:
34
66
- Message sampling requests
35
67
- Context management
36
68
- Content validation
69
+
70
+ Examples:
71
+ >>> handler = SamplingHandler()
72
+ >>> handler._supported_models['claude-3-haiku']
73
+ (0.8, 0.9, 0.7)
74
+ >>> len(handler._supported_models)
75
+ 4
37
76
"""
38
77
39
78
def __init__ (self ):
40
- """Initialize sampling handler."""
79
+ """Initialize sampling handler.
80
+
81
+ Examples:
82
+ >>> handler = SamplingHandler()
83
+ >>> isinstance(handler._supported_models, dict)
84
+ True
85
+ >>> 'claude-3-opus' in handler._supported_models
86
+ True
87
+ >>> handler._supported_models['claude-3-sonnet']
88
+ (0.5, 0.7, 0.9)
89
+ """
41
90
self ._supported_models = {
42
91
# Maps model names to capabilities scores (cost, speed, intelligence)
43
92
"claude-3-haiku" : (0.8 , 0.9 , 0.7 ),
@@ -47,11 +96,26 @@ def __init__(self):
47
96
}
48
97
49
98
async def initialize (self ) -> None :
50
- """Initialize sampling handler."""
99
+ """Initialize sampling handler.
100
+
101
+ Examples:
102
+ >>> import asyncio
103
+ >>> handler = SamplingHandler()
104
+ >>> asyncio.run(handler.initialize())
105
+ >>> # Handler is now initialized
106
+ """
51
107
logger .info ("Initializing sampling handler" )
52
108
53
109
async def shutdown (self ) -> None :
54
- """Shutdown sampling handler."""
110
+ """Shutdown sampling handler.
111
+
112
+ Examples:
113
+ >>> import asyncio
114
+ >>> handler = SamplingHandler()
115
+ >>> asyncio.run(handler.initialize())
116
+ >>> asyncio.run(handler.shutdown())
117
+ >>> # Handler is now shut down
118
+ """
55
119
logger .info ("Shutting down sampling handler" )
56
120
57
121
async def create_message (self , db : Session , request : Dict [str , Any ]) -> CreateMessageResult :
@@ -66,6 +130,64 @@ async def create_message(self, db: Session, request: Dict[str, Any]) -> CreateMe
66
130
67
131
Raises:
68
132
SamplingError: If sampling fails
133
+
134
+ Examples:
135
+ >>> import asyncio
136
+ >>> from unittest.mock import Mock
137
+ >>> handler = SamplingHandler()
138
+ >>> db = Mock()
139
+ >>>
140
+ >>> # Test with valid request
141
+ >>> request = {
142
+ ... "messages": [{
143
+ ... "role": "user",
144
+ ... "content": {"type": "text", "text": "Hello"}
145
+ ... }],
146
+ ... "maxTokens": 100,
147
+ ... "modelPreferences": {
148
+ ... "cost_priority": 0.3,
149
+ ... "speed_priority": 0.3,
150
+ ... "intelligence_priority": 0.4
151
+ ... }
152
+ ... }
153
+ >>> result = asyncio.run(handler.create_message(db, request))
154
+ >>> result.role
155
+ <Role.ASSISTANT: 'assistant'>
156
+ >>> result.content.type
157
+ 'text'
158
+ >>> result.stop_reason
159
+ 'maxTokens'
160
+ >>>
161
+ >>> # Test with no messages
162
+ >>> bad_request = {
163
+ ... "messages": [],
164
+ ... "maxTokens": 100,
165
+ ... "modelPreferences": {
166
+ ... "cost_priority": 0.3,
167
+ ... "speed_priority": 0.3,
168
+ ... "intelligence_priority": 0.4
169
+ ... }
170
+ ... }
171
+ >>> try:
172
+ ... asyncio.run(handler.create_message(db, bad_request))
173
+ ... except SamplingError as e:
174
+ ... print(str(e))
175
+ No messages provided
176
+ >>>
177
+ >>> # Test with no max tokens
178
+ >>> bad_request = {
179
+ ... "messages": [{"role": "user", "content": {"type": "text", "text": "Hi"}}],
180
+ ... "modelPreferences": {
181
+ ... "cost_priority": 0.3,
182
+ ... "speed_priority": 0.3,
183
+ ... "intelligence_priority": 0.4
184
+ ... }
185
+ ... }
186
+ >>> try:
187
+ ... asyncio.run(handler.create_message(db, bad_request))
188
+ ... except SamplingError as e:
189
+ ... print(str(e))
190
+ Max tokens not specified
69
191
"""
70
192
try :
71
193
# Extract request parameters
@@ -121,6 +243,56 @@ def _select_model(self, preferences: ModelPreferences) -> str:
121
243
122
244
Raises:
123
245
SamplingError: If no suitable model found
246
+
247
+ Examples:
248
+ >>> from mcpgateway.models import ModelPreferences, ModelHint
249
+ >>> handler = SamplingHandler()
250
+ >>>
251
+ >>> # Test intelligence priority
252
+ >>> prefs = ModelPreferences(
253
+ ... cost_priority=1.0,
254
+ ... speed_priority=0.0,
255
+ ... intelligence_priority=1.0
256
+ ... )
257
+ >>> handler._select_model(prefs)
258
+ 'claude-3-opus'
259
+ >>>
260
+ >>> # Test speed priority
261
+ >>> prefs = ModelPreferences(
262
+ ... cost_priority=0.0,
263
+ ... speed_priority=1.0,
264
+ ... intelligence_priority=0.0
265
+ ... )
266
+ >>> handler._select_model(prefs)
267
+ 'claude-3-haiku'
268
+ >>>
269
+ >>> # Test balanced preferences
270
+ >>> prefs = ModelPreferences(
271
+ ... cost_priority=0.33,
272
+ ... speed_priority=0.33,
273
+ ... intelligence_priority=0.34
274
+ ... )
275
+ >>> model = handler._select_model(prefs)
276
+ >>> model in handler._supported_models
277
+ True
278
+ >>>
279
+ >>> # Test with model hints
280
+ >>> prefs = ModelPreferences(
281
+ ... hints=[ModelHint(name="opus")],
282
+ ... cost_priority=0.5,
283
+ ... speed_priority=0.3,
284
+ ... intelligence_priority=0.2
285
+ ... )
286
+ >>> handler._select_model(prefs)
287
+ 'claude-3-opus'
288
+ >>>
289
+ >>> # Test empty supported models (should raise error)
290
+ >>> handler._supported_models = {}
291
+ >>> try:
292
+ ... handler._select_model(prefs)
293
+ ... except SamplingError as e:
294
+ ... print(str(e))
295
+ No suitable model found
124
296
"""
125
297
# Check model hints first
126
298
if preferences .hints :
@@ -159,6 +331,29 @@ async def _add_context(self, _db: Session, messages: List[Dict[str, Any]], _cont
159
331
160
332
Returns:
161
333
Messages with added context
334
+
335
+ Examples:
336
+ >>> import asyncio
337
+ >>> from unittest.mock import Mock
338
+ >>> handler = SamplingHandler()
339
+ >>> db = Mock()
340
+ >>>
341
+ >>> messages = [
342
+ ... {"role": "user", "content": {"type": "text", "text": "Hello"}},
343
+ ... {"role": "assistant", "content": {"type": "text", "text": "Hi there!"}}
344
+ ... ]
345
+ >>>
346
+ >>> # Test with 'none' context type
347
+ >>> result = asyncio.run(handler._add_context(db, messages, "none"))
348
+ >>> result == messages
349
+ True
350
+ >>>
351
+ >>> # Test with 'all' context type (currently returns same messages)
352
+ >>> result = asyncio.run(handler._add_context(db, messages, "all"))
353
+ >>> result == messages
354
+ True
355
+ >>> len(result)
356
+ 2
162
357
"""
163
358
# TODO: Implement context gathering based on type
164
359
# For now return original messages
@@ -172,6 +367,65 @@ def _validate_message(self, message: Dict[str, Any]) -> bool:
172
367
173
368
Returns:
174
369
True if valid
370
+
371
+ Examples:
372
+ >>> handler = SamplingHandler()
373
+ >>>
374
+ >>> # Valid text message
375
+ >>> msg = {"role": "user", "content": {"type": "text", "text": "Hello"}}
376
+ >>> handler._validate_message(msg)
377
+ True
378
+ >>>
379
+ >>> # Valid assistant message
380
+ >>> msg = {"role": "assistant", "content": {"type": "text", "text": "Hi!"}}
381
+ >>> handler._validate_message(msg)
382
+ True
383
+ >>>
384
+ >>> # Valid image message
385
+ >>> msg = {
386
+ ... "role": "user",
387
+ ... "content": {
388
+ ... "type": "image",
389
+ ... "data": "base64data",
390
+ ... "mime_type": "image/png"
391
+ ... }
392
+ ... }
393
+ >>> handler._validate_message(msg)
394
+ True
395
+ >>>
396
+ >>> # Missing role
397
+ >>> msg = {"content": {"type": "text", "text": "Hello"}}
398
+ >>> handler._validate_message(msg)
399
+ False
400
+ >>>
401
+ >>> # Invalid role
402
+ >>> msg = {"role": "system", "content": {"type": "text", "text": "Hello"}}
403
+ >>> handler._validate_message(msg)
404
+ False
405
+ >>>
406
+ >>> # Missing content
407
+ >>> msg = {"role": "user"}
408
+ >>> handler._validate_message(msg)
409
+ False
410
+ >>>
411
+ >>> # Invalid content type
412
+ >>> msg = {"role": "user", "content": {"type": "audio"}}
413
+ >>> handler._validate_message(msg)
414
+ False
415
+ >>>
416
+ >>> # Text content not string
417
+ >>> msg = {"role": "user", "content": {"type": "text", "text": 123}}
418
+ >>> handler._validate_message(msg)
419
+ False
420
+ >>>
421
+ >>> # Image missing data
422
+ >>> msg = {"role": "user", "content": {"type": "image", "mime_type": "image/png"}}
423
+ >>> handler._validate_message(msg)
424
+ False
425
+ >>>
426
+ >>> # Invalid structure
427
+ >>> handler._validate_message("not a dict")
428
+ False
175
429
"""
176
430
try :
177
431
# Must have role and content
@@ -205,6 +459,40 @@ def _mock_sample(
205
459
206
460
Returns:
207
461
Sampled response text
462
+
463
+ Examples:
464
+ >>> handler = SamplingHandler()
465
+ >>>
466
+ >>> # Single user message
467
+ >>> messages = [{"role": "user", "content": {"type": "text", "text": "Hello world"}}]
468
+ >>> handler._mock_sample(messages)
469
+ 'You said: Hello world\\ nHere is my response...'
470
+ >>>
471
+ >>> # Conversation with multiple messages
472
+ >>> messages = [
473
+ ... {"role": "user", "content": {"type": "text", "text": "Hi"}},
474
+ ... {"role": "assistant", "content": {"type": "text", "text": "Hello!"}},
475
+ ... {"role": "user", "content": {"type": "text", "text": "How are you?"}}
476
+ ... ]
477
+ >>> handler._mock_sample(messages)
478
+ 'You said: How are you?\\ nHere is my response...'
479
+ >>>
480
+ >>> # Image message
481
+ >>> messages = [{
482
+ ... "role": "user",
483
+ ... "content": {"type": "image", "data": "base64", "mime_type": "image/png"}
484
+ ... }]
485
+ >>> handler._mock_sample(messages)
486
+ 'You said: I see the image you shared.\\ nHere is my response...'
487
+ >>>
488
+ >>> # No user messages
489
+ >>> messages = [{"role": "assistant", "content": {"type": "text", "text": "Hi"}}]
490
+ >>> handler._mock_sample(messages)
491
+ "I'm not sure what to respond to."
492
+ >>>
493
+ >>> # Empty messages
494
+ >>> handler._mock_sample([])
495
+ "I'm not sure what to respond to."
208
496
"""
209
497
# Extract last user message
210
498
last_msg = None
0 commit comments