|
24 | 24 | from mcp.types import ( |
25 | 25 | CreateMessageRequestParams, |
26 | 26 | CreateMessageResult, |
| 27 | + EmbeddedResource, |
27 | 28 | GetPromptResult, |
28 | 29 | InitializeResult, |
29 | 30 | ReadResourceResult, |
@@ -144,6 +145,37 @@ async def sampling_tool(prompt: str, ctx: Context) -> str: |
144 | 145 | else: |
145 | 146 | return f"Sampling result: {str(result.content)[:100]}..." |
146 | 147 |
|
| 148 | + # Tool with sampling capability |
| 149 | + @mcp.tool(description="A tool that uses sampling to generate a resource") |
| 150 | + async def sampling_tool_resource(prompt: str, ctx: Context) -> str: |
| 151 | + await ctx.info(f"Requesting sampling for prompt: {prompt}") |
| 152 | + |
| 153 | + # Request sampling from the client |
| 154 | + result = await ctx.session.create_message( |
| 155 | + messages=[ |
| 156 | + SamplingMessage( |
| 157 | + role="user", |
| 158 | + content=EmbeddedResource( |
| 159 | + type="resource", |
| 160 | + resource=TextResourceContents( |
| 161 | + uri=AnyUrl("file://prompt"), |
| 162 | + text=prompt, |
| 163 | + mimeType="text/plain", |
| 164 | + ), |
| 165 | + ), |
| 166 | + ) |
| 167 | + ], |
| 168 | + max_tokens=100, |
| 169 | + temperature=0.7, |
| 170 | + ) |
| 171 | + |
| 172 | + await ctx.info(f"Received sampling result from model: {result.model}") |
| 173 | + # Handle different content types |
| 174 | + if result.content.type == "text": |
| 175 | + return f"Sampling result: {result.content.text[:100]}..." |
| 176 | + else: |
| 177 | + return f"Sampling result: {str(result.content)[:100]}..." |
| 178 | + |
147 | 179 | # Tool that sends notifications and logging |
148 | 180 | @mcp.tool(description="A tool that demonstrates notifications and logging") |
149 | 181 | async def notification_tool(message: str, ctx: Context) -> str: |
@@ -694,12 +726,13 @@ async def progress_callback( |
694 | 726 | assert len(collector.log_messages) > 0 |
695 | 727 |
|
696 | 728 | # 3. Test sampling tool |
697 | | - prompt = "What is the meaning of life?" |
698 | | - sampling_result = await session.call_tool("sampling_tool", {"prompt": prompt}) |
699 | | - assert len(sampling_result.content) == 1 |
700 | | - assert isinstance(sampling_result.content[0], TextContent) |
701 | | - assert "Sampling result:" in sampling_result.content[0].text |
702 | | - assert "This is a simulated LLM response" in sampling_result.content[0].text |
| 729 | + for tool in ["sampling_tool", "sampling_tool_resource"]: |
| 730 | + prompt = "What is the meaning of life?" |
| 731 | + sampling_result = await session.call_tool(tool, {"prompt": prompt}) |
| 732 | + assert len(sampling_result.content) == 1 |
| 733 | + assert isinstance(sampling_result.content[0], TextContent) |
| 734 | + assert "Sampling result:" in sampling_result.content[0].text |
| 735 | + assert "This is a simulated LLM response" in sampling_result.content[0].text |
703 | 736 |
|
704 | 737 | # Verify we received log messages from the sampling tool |
705 | 738 | assert len(collector.log_messages) > 0 |
@@ -810,6 +843,12 @@ async def sampling_callback( |
810 | 843 | # Simulate LLM response based on the input |
811 | 844 | if params.messages and isinstance(params.messages[0].content, TextContent): |
812 | 845 | input_text = params.messages[0].content.text |
| 846 | + elif ( |
| 847 | + params.messages |
| 848 | + and isinstance(params.messages[0].content, EmbeddedResource) |
| 849 | + and isinstance(params.messages[0].content.resource, TextResourceContents) |
| 850 | + ): |
| 851 | + input_text = params.messages[0].content.resource.text |
813 | 852 | else: |
814 | 853 | input_text = "No input" |
815 | 854 | response_text = f"This is a simulated LLM response to: {input_text}" |
|
0 commit comments