42 changes: 42 additions & 0 deletions src/mcp_agent/llm/augmented_llm_slow.py
@@ -0,0 +1,42 @@
import asyncio
from typing import Any, List, Optional, Union

from mcp_agent.llm.augmented_llm import (
MessageParamT,
RequestParams,
)
from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
from mcp_agent.llm.provider_types import Provider
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart


class SlowLLM(PassthroughLLM):
"""
    A PassthroughLLM variant that sleeps for 3 seconds before responding.

This is useful for testing scenarios where you want to simulate slow responses
or for debugging timing-related issues in parallel workflows.
"""

    def __init__(
        self, provider: Provider = Provider.FAST_AGENT, name: str = "Slow", **kwargs: Any
    ) -> None:
super().__init__(name=name, provider=provider, **kwargs)

async def generate_str(
self,
message: Union[str, MessageParamT, List[MessageParamT]],
request_params: Optional[RequestParams] = None,
) -> str:
"""Sleep for 3 seconds then return the input message as a string."""
await asyncio.sleep(3)
return await super().generate_str(message, request_params)

async def _apply_prompt_provider_specific(
self,
multipart_messages: List["PromptMessageMultipart"],
request_params: RequestParams | None = None,
) -> PromptMessageMultipart:
"""Sleep for 3 seconds then apply prompt like PassthroughLLM."""
await asyncio.sleep(3)
return await super()._apply_prompt_provider_specific(multipart_messages, request_params)
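
For illustration, a minimal sketch of how the new class behaves, assuming PassthroughLLM needs no additional constructor arguments: each call waits roughly 3 seconds and then echoes the input, exactly as the passthrough model would.

import asyncio

from mcp_agent.llm.augmented_llm_slow import SlowLLM


async def demo() -> None:
    llm = SlowLLM()
    # Waits ~3 seconds, then echoes the message back like PassthroughLLM.
    result = await llm.generate_str("hello")
    print(result)


asyncio.run(demo())
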
4 changes: 4 additions & 0 deletions src/mcp_agent/llm/model_factory.py
@@ -8,6 +8,7 @@
from mcp_agent.core.request_params import RequestParams
from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
from mcp_agent.llm.augmented_llm_playback import PlaybackLLM
from mcp_agent.llm.augmented_llm_slow import SlowLLM
from mcp_agent.llm.provider_types import Provider
from mcp_agent.llm.providers.augmented_llm_anthropic import AnthropicAugmentedLLM
from mcp_agent.llm.providers.augmented_llm_azure import AzureOpenAIAugmentedLLM
@@ -29,6 +30,7 @@
Type[OpenAIAugmentedLLM],
Type[PassthroughLLM],
Type[PlaybackLLM],
Type[SlowLLM],
Type[DeepSeekAugmentedLLM],
Type[OpenRouterAugmentedLLM],
Type[TensorZeroAugmentedLLM],
@@ -73,6 +75,7 @@ class ModelFactory:
DEFAULT_PROVIDERS = {
"passthrough": Provider.FAST_AGENT,
"playback": Provider.FAST_AGENT,
"slow": Provider.FAST_AGENT,
"gpt-4o": Provider.OPENAI,
"gpt-4o-mini": Provider.OPENAI,
"gpt-4.1": Provider.OPENAI,
@@ -139,6 +142,7 @@ class ModelFactory:
# This overrides the provider-based class selection
MODEL_SPECIFIC_CLASSES: Dict[str, LLMClass] = {
"playback": PlaybackLLM,
"slow": SlowLLM,
}

@classmethod
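
The comment above MODEL_SPECIFIC_CLASSES states that model-name matches override provider-based selection. A toy sketch of that precedence rule, with stub classes standing in for the real ones — this is an illustration, not the factory's actual resolution code:

from typing import Dict, Type

class PassthroughLLM: ...  # stand-in for the provider-selected default
class PlaybackLLM: ...     # stand-ins for the model-specific overrides
class SlowLLM: ...

MODEL_SPECIFIC_CLASSES: Dict[str, Type] = {"playback": PlaybackLLM, "slow": SlowLLM}

def resolve_class(model_name: str) -> Type:
    # An exact model-name match wins; anything else falls back to the
    # provider-based selection (collapsed to a single stub class here).
    return MODEL_SPECIFIC_CLASSES.get(model_name, PassthroughLLM)

assert resolve_class("slow") is SlowLLM
assert resolve_class("gpt-4o") is PassthroughLLM
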
6 changes: 5 additions & 1 deletion tests/integration/sampling/fastagent.config.yaml
@@ -23,7 +23,11 @@ mcp:
args: ["run", "sampling_test_server.py"]
sampling:
model: "passthrough"

slow_sampling:
command: "uv"
args: ["run", "sampling_test_server.py"]
sampling:
model: "slow"
sampling_test_no_config:
command: "uv"
args: ["run", "sampling_test_server.py"]
5 changes: 4 additions & 1 deletion tests/integration/sampling/live.py
@@ -7,13 +7,16 @@


# Define the agent
@fast.agent(servers=["sampling_test"])
@fast.agent(servers=["sampling_test", "slow_sampling"])
async def main():
# use the --model command line switch or agent arguments to change model
async with fast.run() as agent:
result = await agent.send('***CALL_TOOL sampling_test-sample {"to_sample": "123foo"}')
print(f"RESULT: {result}")

result = await agent.send('***CALL_TOOL slow_sampling-sample_parallel')
print(f"RESULT: {result}")


if __name__ == "__main__":
asyncio.run(main())
41 changes: 41 additions & 0 deletions tests/integration/sampling/sampling_test_server.py
@@ -61,6 +61,47 @@ async def sample_many(ctx: Context) -> CallToolResult:
return CallToolResult(content=[TextContent(type="text", text=str(result))])


@mcp.tool()
async def sample_parallel(ctx: Context, count: int = 5) -> CallToolResult:
"""Tool that makes multiple concurrent sampling requests to test parallel processing"""
try:
logger.info(f"Making {count} concurrent sampling requests")

        import asyncio

        # Helper that issues a single sampling request back to the client
async def _send_sampling(request: int):
return await ctx.session.create_message(
max_tokens=100,
messages=[SamplingMessage(
role="user",
content=TextContent(type="text", text=f"Parallel request {request+1}")
)],
)


        # Execute all requests concurrently
        results = await asyncio.gather(*(_send_sampling(i) for i in range(count)))

# Combine results
response_texts = [result.content.text for result in results]
combined_response = f"Completed {len(results)} parallel requests: " + ", ".join(response_texts[:3])
if len(response_texts) > 3:
combined_response += f"... and {len(response_texts) - 3} more"

logger.info(f"Parallel sampling completed: {combined_response}")
return CallToolResult(content=[TextContent(type="text", text=combined_response)])

except Exception as e:
logger.error(f"Error in sample_parallel tool: {e}", exc_info=True)
return CallToolResult(isError=True, content=[TextContent(type="text", text=f"Error: {str(e)}")])


if __name__ == "__main__":
logger.info("Starting sampling test server...")
mcp.run()
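
With slow_sampling configured to use the "slow" model, the five concurrent requests in sample_parallel should finish in roughly one 3-second delay rather than 15 seconds. A self-contained sketch of that expectation using plain asyncio, where simulated_sample is a hypothetical stand-in for the sampling round-trip:

import asyncio
import time


async def simulated_sample(i: int) -> str:
    await asyncio.sleep(3)  # stands in for SlowLLM's fixed 3-second delay
    return f"Parallel request {i + 1}"


async def main() -> None:
    start = time.perf_counter()
    results = await asyncio.gather(*(simulated_sample(i) for i in range(5)))
    elapsed = time.perf_counter() - start
    # ~3s when the requests truly run concurrently; ~15s if serialized.
    print(f"Completed {len(results)} requests in {elapsed:.1f}s")


asyncio.run(main())
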