diff --git a/src/mcp_agent/llm/augmented_llm_slow.py b/src/mcp_agent/llm/augmented_llm_slow.py
new file mode 100644
index 00000000..b4026507
--- /dev/null
+++ b/src/mcp_agent/llm/augmented_llm_slow.py
@@ -0,0 +1,42 @@
+import asyncio
+from typing import Any, List, Optional, Union
+
+from mcp_agent.llm.augmented_llm import (
+    MessageParamT,
+    RequestParams,
+)
+from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
+from mcp_agent.llm.provider_types import Provider
+from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
+
+
+class SlowLLM(PassthroughLLM):
+    """
+    A PassthroughLLM variant that sleeps for 3 seconds before responding.
+
+    This is useful for testing scenarios where you want to simulate slow responses
+    or for debugging timing-related issues in parallel workflows.
+    """
+
+    def __init__(
+        self, provider=Provider.FAST_AGENT, name: str = "Slow", **kwargs: dict[str, Any]
+    ) -> None:
+        super().__init__(name=name, provider=provider, **kwargs)
+
+    async def generate_str(
+        self,
+        message: Union[str, MessageParamT, List[MessageParamT]],
+        request_params: Optional[RequestParams] = None,
+    ) -> str:
+        """Sleep for 3 seconds, then return the input message as a string."""
+        await asyncio.sleep(3)
+        return await super().generate_str(message, request_params)
+
+    async def _apply_prompt_provider_specific(
+        self,
+        multipart_messages: List["PromptMessageMultipart"],
+        request_params: RequestParams | None = None,
+    ) -> PromptMessageMultipart:
+        """Sleep for 3 seconds, then apply the prompt like PassthroughLLM."""
+        await asyncio.sleep(3)
+        return await super()._apply_prompt_provider_specific(multipart_messages, request_params)
diff --git a/src/mcp_agent/llm/model_factory.py b/src/mcp_agent/llm/model_factory.py
index 318dab00..1e7e1fe4 100644
--- a/src/mcp_agent/llm/model_factory.py
+++ b/src/mcp_agent/llm/model_factory.py
@@ -8,6 +8,7 @@ from mcp_agent.core.request_params import RequestParams
 from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
 from mcp_agent.llm.augmented_llm_playback import PlaybackLLM
+from mcp_agent.llm.augmented_llm_slow import SlowLLM
 from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.providers.augmented_llm_anthropic import AnthropicAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_azure import AzureOpenAIAugmentedLLM
@@ -29,6 +30,7 @@
     Type[OpenAIAugmentedLLM],
     Type[PassthroughLLM],
     Type[PlaybackLLM],
+    Type[SlowLLM],
     Type[DeepSeekAugmentedLLM],
     Type[OpenRouterAugmentedLLM],
     Type[TensorZeroAugmentedLLM],
@@ -73,6 +75,7 @@ class ModelFactory:
     DEFAULT_PROVIDERS = {
         "passthrough": Provider.FAST_AGENT,
         "playback": Provider.FAST_AGENT,
+        "slow": Provider.FAST_AGENT,
         "gpt-4o": Provider.OPENAI,
         "gpt-4o-mini": Provider.OPENAI,
         "gpt-4.1": Provider.OPENAI,
@@ -139,6 +142,7 @@ class ModelFactory:
     # This overrides the provider-based class selection
     MODEL_SPECIFIC_CLASSES: Dict[str, LLMClass] = {
         "playback": PlaybackLLM,
+        "slow": SlowLLM,
     }
 
     @classmethod
diff --git a/tests/integration/sampling/fastagent.config.yaml b/tests/integration/sampling/fastagent.config.yaml
index 8c7cffa1..03d962cb 100644
--- a/tests/integration/sampling/fastagent.config.yaml
+++ b/tests/integration/sampling/fastagent.config.yaml
@@ -23,7 +23,11 @@ mcp:
       args: ["run", "sampling_test_server.py"]
       sampling:
         model: "passthrough"
-
+    slow_sampling:
+      command: "uv"
+      args: ["run", "sampling_test_server.py"]
+      sampling:
+        model: "slow"
     sampling_test_no_config:
       command: "uv"
       args: ["run", "sampling_test_server.py"]
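How the pieces above fit together: the `"slow"` entry in `MODEL_SPECIFIC_CLASSES` means the model string `slow` resolves directly to `SlowLLM`, so it can be requested anywhere a model name is accepted. A minimal sketch of that usage (the app and agent names here are illustrative, not part of this patch):

```python
import asyncio

from mcp_agent.core.fastagent import FastAgent

fast = FastAgent("slow-demo")  # hypothetical app name


# "slow" is resolved by ModelFactory via MODEL_SPECIFIC_CLASSES -> SlowLLM
@fast.agent(name="slow_demo", model="slow")
async def main():
    async with fast.run() as agent:
        # SlowLLM waits ~3 seconds, then echoes the input like PassthroughLLM
        print(await agent.send("hello"))


if __name__ == "__main__":
    asyncio.run(main())
```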
diff --git a/tests/integration/sampling/live.py b/tests/integration/sampling/live.py
index 732aebaa..93bc577a 100644
--- a/tests/integration/sampling/live.py
+++ b/tests/integration/sampling/live.py
@@ -7,13 +7,16 @@
 
 
 # Define the agent
-@fast.agent(servers=["sampling_test"])
+@fast.agent(servers=["sampling_test", "slow_sampling"])
 async def main():
     # use the --model command line switch or agent arguments to change model
     async with fast.run() as agent:
         result = await agent.send('***CALL_TOOL sampling_test-sample {"to_sample": "123foo"}')
         print(f"RESULT: {result}")
 
+        result = await agent.send('***CALL_TOOL slow_sampling-sample_parallel')
+        print(f"RESULT: {result}")
+
 
 if __name__ == "__main__":
     asyncio.run(main())
diff --git a/tests/integration/sampling/sampling_test_server.py b/tests/integration/sampling/sampling_test_server.py
index b26585c5..89d98d54 100644
--- a/tests/integration/sampling/sampling_test_server.py
+++ b/tests/integration/sampling/sampling_test_server.py
@@ -61,6 +61,47 @@ async def sample_many(ctx: Context) -> CallToolResult:
         return CallToolResult(content=[TextContent(type="text", text=str(result))])
 
 
+@mcp.tool()
+async def sample_parallel(ctx: Context, count: int = 5) -> CallToolResult:
+    """Tool that makes multiple concurrent sampling requests to test parallel processing."""
+    try:
+        logger.info(f"Making {count} concurrent sampling requests")
+
+        import asyncio
+
+        async def _send_sampling(request: int):
+            return await ctx.session.create_message(
+                max_tokens=100,
+                messages=[
+                    SamplingMessage(
+                        role="user",
+                        content=TextContent(type="text", text=f"Parallel request {request + 1}"),
+                    )
+                ],
+            )
+
+        # Build one coroutine per request, then execute them all concurrently
+        tasks = [_send_sampling(i) for i in range(count)]
+        results = await asyncio.gather(*tasks)
+
+        # Combine results
+        response_texts = [result.content.text for result in results]
+        combined_response = f"Completed {len(results)} parallel requests: " + ", ".join(
+            response_texts[:3]
+        )
+        if len(response_texts) > 3:
+            combined_response += f"... and {len(response_texts) - 3} more"
+
+        logger.info(f"Parallel sampling completed: {combined_response}")
+        return CallToolResult(content=[TextContent(type="text", text=combined_response)])
+
+    except Exception as e:
+        logger.error(f"Error in sample_parallel tool: {e}", exc_info=True)
+        return CallToolResult(
+            isError=True, content=[TextContent(type="text", text=f"Error: {str(e)}")]
+        )
+
+
 if __name__ == "__main__":
     logger.info("Starting sampling test server...")
     mcp.run()
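Note on the `sample_parallel` hunk: the original draft built a `tasks` list of coroutines and then called `asyncio.gather` on a second, freshly created batch, leaving the first batch un-awaited (which raises "coroutine was never awaited" warnings); the version above gathers the single batch it builds.

The point of pairing `sample_parallel` with the `slow` sampling model is to verify that sampling requests are serviced concurrently: five requests against a 3-second-per-response model should finish in roughly 3 seconds, not 15. A self-contained sketch of that timing property, using plain asyncio with a hypothetical `slow_echo` standing in for SlowLLM:

```python
import asyncio
import time


async def slow_echo(msg: str) -> str:
    # Stand-in for SlowLLM: fixed 3-second delay, then echo the input
    await asyncio.sleep(3)
    return msg


async def main() -> None:
    start = time.monotonic()
    results = await asyncio.gather(*(slow_echo(f"Parallel request {i + 1}") for i in range(5)))
    elapsed = time.monotonic() - start
    # Concurrent execution -> ~3s total; sequential would be ~15s
    print(f"Completed {len(results)} requests in {elapsed:.1f}s")


if __name__ == "__main__":
    asyncio.run(main())
```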