Skip to content

Commit 11d916a

Browse files
committed
feat: support embedded resources in sampling
1 parent 6353dd1 commit 11d916a

File tree

2 files changed

+54
-15
lines changed

2 files changed

+54
-15
lines changed

src/mcp/types.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -651,14 +651,6 @@ class ImageContent(BaseModel):
651651
model_config = ConfigDict(extra="allow")
652652

653653

654-
class SamplingMessage(BaseModel):
655-
"""Describes a message issued to or received from an LLM API."""
656-
657-
role: Role
658-
content: TextContent | ImageContent
659-
model_config = ConfigDict(extra="allow")
660-
661-
662654
class EmbeddedResource(BaseModel):
663655
"""
664656
The contents of a resource, embedded into a prompt or tool call result.
@@ -673,6 +665,14 @@ class EmbeddedResource(BaseModel):
673665
model_config = ConfigDict(extra="allow")
674666

675667

668+
class SamplingMessage(BaseModel):
669+
"""Describes a message issued to or received from an LLM API."""
670+
671+
role: Role
672+
content: TextContent | ImageContent | EmbeddedResource
673+
model_config = ConfigDict(extra="allow")
674+
675+
676676
class PromptMessage(BaseModel):
677677
"""Describes a message returned as part of a prompt."""
678678

@@ -960,7 +960,7 @@ class CreateMessageResult(Result):
960960
"""The client's response to a sampling/create_message request from the server."""
961961

962962
role: Role
963-
content: TextContent | ImageContent
963+
content: TextContent | ImageContent | EmbeddedResource
964964
model: str
965965
"""The name of the model that generated the message."""
966966
stopReason: StopReason | None = None

tests/server/fastmcp/test_integration.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from mcp.types import (
2525
CreateMessageRequestParams,
2626
CreateMessageResult,
27+
EmbeddedResource,
2728
GetPromptResult,
2829
InitializeResult,
2930
ReadResourceResult,
@@ -144,6 +145,37 @@ async def sampling_tool(prompt: str, ctx: Context) -> str:
144145
else:
145146
return f"Sampling result: {str(result.content)[:100]}..."
146147

148+
# Tool that exercises sampling with embedded-resource message content
149+
@mcp.tool(description="A tool that uses sampling to generate a resource")
150+
async def sampling_tool_resource(prompt: str, ctx: Context) -> str:
151+
await ctx.info(f"Requesting sampling for prompt: {prompt}")
152+
153+
# Request sampling from the client
154+
result = await ctx.session.create_message(
155+
messages=[
156+
SamplingMessage(
157+
role="user",
158+
content=EmbeddedResource(
159+
type="resource",
160+
resource=TextResourceContents(
161+
uri=AnyUrl("file://prompt"),
162+
text=prompt,
163+
mimeType="text/plain",
164+
),
165+
),
166+
)
167+
],
168+
max_tokens=100,
169+
temperature=0.7,
170+
)
171+
172+
await ctx.info(f"Received sampling result from model: {result.model}")
173+
# Handle different content types
174+
if result.content.type == "text":
175+
return f"Sampling result: {result.content.text[:100]}..."
176+
else:
177+
return f"Sampling result: {str(result.content)[:100]}..."
178+
147179
# Tool that sends notifications and logging
148180
@mcp.tool(description="A tool that demonstrates notifications and logging")
149181
async def notification_tool(message: str, ctx: Context) -> str:
@@ -694,12 +726,13 @@ async def progress_callback(
694726
assert len(collector.log_messages) > 0
695727

696728
# 3. Test sampling tool
697-
prompt = "What is the meaning of life?"
698-
sampling_result = await session.call_tool("sampling_tool", {"prompt": prompt})
699-
assert len(sampling_result.content) == 1
700-
assert isinstance(sampling_result.content[0], TextContent)
701-
assert "Sampling result:" in sampling_result.content[0].text
702-
assert "This is a simulated LLM response" in sampling_result.content[0].text
729+
for tool in ["sampling_tool", "sampling_tool_resource"]:
730+
prompt = "What is the meaning of life?"
731+
sampling_result = await session.call_tool(tool, {"prompt": prompt})
732+
assert len(sampling_result.content) == 1
733+
assert isinstance(sampling_result.content[0], TextContent)
734+
assert "Sampling result:" in sampling_result.content[0].text
735+
assert "This is a simulated LLM response" in sampling_result.content[0].text
703736

704737
# Verify we received log messages from the sampling tool
705738
assert len(collector.log_messages) > 0
@@ -810,6 +843,12 @@ async def sampling_callback(
810843
# Simulate LLM response based on the input
811844
if params.messages and isinstance(params.messages[0].content, TextContent):
812845
input_text = params.messages[0].content.text
846+
elif (
847+
params.messages
848+
and isinstance(params.messages[0].content, EmbeddedResource)
849+
and isinstance(params.messages[0].content.resource, TextResourceContents)
850+
):
851+
input_text = params.messages[0].content.resource.text
813852
else:
814853
input_text = "No input"
815854
response_text = f"This is a simulated LLM response to: {input_text}"

0 commit comments

Comments
 (0)