|
1 | 1 | import os |
2 | 2 | from pathlib import Path |
3 | 3 |
|
| 4 | +import pytest |
4 | 5 | from langchain_core.messages import AIMessage |
5 | 6 | from langchain_core.tools import BaseTool |
| 7 | +from mcp.types import Prompt |
6 | 8 |
|
7 | 9 | from langchain_mcp_adapters.client import MultiServerMCPClient |
8 | 10 | from langchain_mcp_adapters.tools import load_mcp_tools |
@@ -161,3 +163,108 @@ async def test_get_prompt(): |
161 | 163 | assert isinstance(messages[0], AIMessage) |
162 | 164 | assert "You are a helpful assistant" in messages[0].content |
163 | 165 | assert "math, addition, multiplication" in messages[0].content |
| 166 | + |
| 167 | + |
| 168 | +async def test_get_prompts(): |
| 169 | + """Test retrieving prompts from MCP servers.""" |
| 170 | + # Get the absolute path to the server scripts |
| 171 | + current_dir = Path(__file__).parent |
| 172 | + math_server_path = os.path.join(current_dir, "servers/math_server.py") |
| 173 | + |
| 174 | + client = MultiServerMCPClient( |
| 175 | + { |
| 176 | + "math": { |
| 177 | + "command": "python3", |
| 178 | + "args": [math_server_path], |
| 179 | + "transport": "stdio", |
| 180 | + } |
| 181 | + }, |
| 182 | + ) |
| 183 | + # Test getting prompts from the math server |
| 184 | + prompts = await client.get_prompts( |
| 185 | + "math", |
| 186 | + ) |
| 187 | + |
| 188 | + # Check that we got multiple Prompts back |
| 189 | + assert len(prompts) == 3 |
| 190 | + assert all(isinstance(prompt, Prompt) for prompt in prompts) |
| 191 | + |
| 192 | + # Check the first prompt (configure_assistant) |
| 193 | + configure_prompt = [p for p in prompts if p.name == "configure_assistant"][0] |
| 194 | + assert configure_prompt.description == "" |
| 195 | + assert configure_prompt.title is None |
| 196 | + assert len(configure_prompt.arguments) == 1 |
| 197 | + assert configure_prompt.arguments[0].name == "skills" |
| 198 | + assert configure_prompt.arguments[0].required is True |
| 199 | + |
| 200 | + # Check the second prompt (math_problem_solver) |
| 201 | + solver_prompt = [p for p in prompts if p.name == "math_problem_solver"][0] |
| 202 | + assert solver_prompt.description == "" |
| 203 | + assert len(solver_prompt.arguments) == 1 |
| 204 | + assert solver_prompt.arguments[0].name == "problem_type" |
| 205 | + assert solver_prompt.arguments[0].required is True |
| 206 | + |
| 207 | + # Check the third prompt (calculation_guide) |
| 208 | + guide_prompt = [p for p in prompts if p.name == "calculation_guide"][0] |
| 209 | + assert guide_prompt.description == "" |
| 210 | + # This prompt has no arguments |
| 211 | + assert guide_prompt.arguments == [] |
| 212 | + |
| 213 | + |
| 214 | +async def test_get_prompts_invalid_server(): |
| 215 | + """Test that get_prompts raises an error for invalid server name.""" |
| 216 | + current_dir = Path(__file__).parent |
| 217 | + math_server_path = os.path.join(current_dir, "servers/math_server.py") |
| 218 | + |
| 219 | + client = MultiServerMCPClient( |
| 220 | + { |
| 221 | + "math": { |
| 222 | + "command": "python3", |
| 223 | + "args": [math_server_path], |
| 224 | + "transport": "stdio", |
| 225 | + } |
| 226 | + }, |
| 227 | + ) |
| 228 | + |
| 229 | + # Test getting prompts from a non-existent server |
| 230 | + with pytest.raises(ValueError) as exc_info: |
| 231 | + await client.get_prompts("nonexistent_server") |
| 232 | + |
| 233 | + error_msg = str(exc_info.value) |
| 234 | + assert "Couldn't find a server with name 'nonexistent_server'" in error_msg |
| 235 | + assert "expected one of '['math']'" in error_msg |
| 236 | + |
| 237 | + |
| 238 | +async def test_get_prompts_multiple_servers( |
| 239 | + socket_enabled, |
| 240 | + websocket_server, |
| 241 | + websocket_server_port: int, |
| 242 | +): |
| 243 | + """Test retrieving prompts from multiple servers.""" |
| 244 | + current_dir = Path(__file__).parent |
| 245 | + math_server_path = os.path.join(current_dir, "servers/math_server.py") |
| 246 | + weather_server_path = os.path.join(current_dir, "servers/weather_server.py") |
| 247 | + |
| 248 | + client = MultiServerMCPClient( |
| 249 | + { |
| 250 | + "math": { |
| 251 | + "command": "python3", |
| 252 | + "args": [math_server_path], |
| 253 | + "transport": "stdio", |
| 254 | + }, |
| 255 | + "weather": { |
| 256 | + "command": "python3", |
| 257 | + "args": [weather_server_path], |
| 258 | + "transport": "stdio", |
| 259 | + }, |
| 260 | + }, |
| 261 | + ) |
| 262 | + |
| 263 | + # Test getting prompts from math server |
| 264 | + math_prompts = await client.get_prompts("math") |
| 265 | + assert len(math_prompts) == 3 |
| 266 | + |
| 267 | + # Test getting prompts from weather server (may have different prompts) |
| 268 | + weather_prompts = await client.get_prompts("weather") |
| 269 | + # Weather server may or may not have prompts, just verify it doesn't crash |
| 270 | + assert isinstance(weather_prompts, list) |
0 commit comments