|
54 | 54 | import mcp
|
55 | 55 | from mcp import ClientSession
|
56 | 56 | from mcp.client.sse import sse_client
|
| 57 | +import httpx |
| 58 | +import re |
57 | 59 |
|
58 | 60 | # Import dotenv for loading environment variables
|
59 | 61 | try:
|
@@ -334,15 +336,22 @@ async def invoke_mcp_tool(mcp_registry_url: str, server_name: str, tool_name: st
|
334 | 336 | # Construct the MCP server URL from the registry URL and server name using standard URL parsing
|
335 | 337 | parsed_url = urlparse(mcp_registry_url)
|
336 | 338 |
|
337 |
| - # Extract the scheme and netloc (hostname:port) from the parsed URL |
| 339 | + # Extract the scheme, netloc and path from the parsed URL |
338 | 340 | scheme = parsed_url.scheme
|
339 | 341 | netloc = parsed_url.netloc
|
| 342 | + path = parsed_url.path |
| 343 | + |
| 344 | + # If the path ends with '/sse', remove it to get the base path |
| 345 | + if path.endswith('/sse'): |
| 346 | + base_path = path[:-4] # Remove '/sse' from the end |
| 347 | + else: |
| 348 | + base_path = path |
340 | 349 |
|
341 |
| - # Construct the base URL with scheme and netloc |
342 |
| - base_url = f"{scheme}://{netloc}" |
| 350 | + # Construct the base URL with scheme, netloc and base path |
| 351 | + base_url = f"{scheme}://{netloc}{base_path}" |
343 | 352 |
|
344 | 353 | # Create the server URL by joining the base URL with the server name and sse path
|
345 |
| - server_url = urljoin(base_url, f"{server_name}/sse") |
| 354 | + server_url = urljoin(base_url + '/', f"{server_name}/sse") |
346 | 355 | print(f"Server URL: {server_url}")
|
347 | 356 |
|
348 | 357 | # Prepare headers based on authentication method
|
@@ -399,6 +408,60 @@ def redact_sensitive_value(value: str, show_chars: int = 4) -> str:
|
399 | 408 | return "*" * len(value) if value else ""
|
400 | 409 | return value[:show_chars] + "*" * (len(value) - show_chars)
|
401 | 410 |
|
| 411 | +def normalize_sse_endpoint_url_for_request(url_str: str, original_sse_url: str) -> str: |
| 412 | + """ |
| 413 | + Normalize URLs in HTTP requests by preserving mount paths for non-mounted servers. |
| 414 | + |
| 415 | + This function only applies fixes when the request is for the same server as the original SSE URL. |
| 416 | + It should NOT modify requests to different servers (like currenttime, fininfo, etc.) |
| 417 | + |
| 418 | + Example: |
| 419 | + - Original SSE: http://localhost/mcpgw2/sse |
| 420 | + - Request to same server: http://localhost/messages/?session_id=123 -> http://localhost/mcpgw2/messages/?session_id=123 |
| 421 | + - Request to different server: http://localhost/currenttime/messages/?session_id=123 -> unchanged (already correct) |
| 422 | + """ |
| 423 | + if '/messages/' not in url_str: |
| 424 | + return url_str |
| 425 | + |
| 426 | + # Parse the original SSE URL to extract the base path |
| 427 | + from urllib.parse import urlparse |
| 428 | + parsed_original = urlparse(original_sse_url) |
| 429 | + parsed_current = urlparse(url_str) |
| 430 | + |
| 431 | + # Only apply fixes if this is the same host/port as the original SSE URL |
| 432 | + if parsed_current.netloc != parsed_original.netloc: |
| 433 | + return url_str |
| 434 | + |
| 435 | + original_path = parsed_original.path |
| 436 | + |
| 437 | + # Remove /sse from the original path to get the base mount path |
| 438 | + if original_path.endswith('/sse'): |
| 439 | + base_mount_path = original_path[:-4] # Remove '/sse' |
| 440 | + else: |
| 441 | + base_mount_path = original_path |
| 442 | + |
| 443 | + # Only apply the fix if: |
| 444 | + # 1. There is a base mount path (non-empty) |
| 445 | + # 2. The current path is exactly /messages/... (indicating it's missing the mount path) |
| 446 | + # 3. The current path doesn't already contain a mount path |
| 447 | + if (base_mount_path and |
| 448 | + parsed_current.path.startswith('/messages/') and |
| 449 | + not parsed_current.path.startswith(base_mount_path)): |
| 450 | + |
| 451 | + # The mount path is missing, we need to add it back |
| 452 | + # Reconstruct the URL with the mount path |
| 453 | + new_path = base_mount_path + parsed_current.path |
| 454 | + fixed_url = f"{parsed_current.scheme}://{parsed_current.netloc}{new_path}" |
| 455 | + if parsed_current.query: |
| 456 | + fixed_url += f"?{parsed_current.query}" |
| 457 | + if parsed_current.fragment: |
| 458 | + fixed_url += f"#{parsed_current.fragment}" |
| 459 | + |
| 460 | + logger.debug(f"Fixed mount path in request URL: {url_str} -> {fixed_url}") |
| 461 | + return fixed_url |
| 462 | + |
| 463 | + return url_str |
| 464 | + |
402 | 465 | def load_system_prompt():
|
403 | 466 | """
|
404 | 467 | Load the system prompt template from the system_prompt.txt file.
|
@@ -593,89 +656,117 @@ async def main():
|
593 | 656 | redacted_headers[k] = v
|
594 | 657 | logger.info(f"Using authentication headers: {redacted_headers}")
|
595 | 658 |
|
596 |
| - # Initialize MCP client with the server configuration and authentication headers |
597 |
| - client = MultiServerMCPClient( |
598 |
| - { |
599 |
| - "mcp_registry": { |
600 |
| - "url": server_url, |
601 |
| - "transport": "sse", |
602 |
| - "headers": auth_headers |
603 |
| - } |
604 |
| - } |
605 |
| - ) |
606 |
| - logger.info("Connected to MCP server successfully with authentication") |
607 |
| - |
608 |
| - # Get available tools from MCP and display them |
609 |
| - mcp_tools = await client.get_tools() |
610 |
| - logger.info(f"Available MCP tools: {[tool.name for tool in mcp_tools]}") |
611 |
| - |
612 |
| - # Add the calculator and invoke_mcp_tool to the tools array |
613 |
| - # The invoke_mcp_tool function already supports authentication parameters |
614 |
| - all_tools = [calculator, invoke_mcp_tool] + mcp_tools |
615 |
| - logger.info(f"All available tools: {[tool.name if hasattr(tool, 'name') else tool.__name__ for tool in all_tools]}") |
| 659 | + # Apply monkey patch to fix mount path issues in httpx requests |
| 660 | + # This fixes the issue where non-mounted servers with default paths lose their mount path |
| 661 | + # in POST requests to /messages/ endpoints |
| 662 | + original_request = httpx.AsyncClient.request |
616 | 663 |
|
617 |
| - # Create the agent with the model and all tools |
618 |
| - agent = create_react_agent( |
619 |
| - model, |
620 |
| - all_tools |
621 |
| - ) |
| 664 | + async def patched_request(self, method, url, **kwargs): |
| 665 | + # Fix mount path issues in requests |
| 666 | + if isinstance(url, str) and '/messages/' in url: |
| 667 | + url = normalize_sse_endpoint_url_for_request(url, server_url) |
| 668 | + elif hasattr(url, '__str__') and '/messages/' in str(url): |
| 669 | + url = normalize_sse_endpoint_url_for_request(str(url), server_url) |
| 670 | + return await original_request(self, method, url, **kwargs) |
622 | 671 |
|
623 |
| - # Load and format the system prompt with the current time and MCP registry URL |
624 |
| - system_prompt_template = load_system_prompt() |
| 672 | + # Apply the patch |
| 673 | + httpx.AsyncClient.request = patched_request |
| 674 | + logger.info("Applied httpx monkey patch to fix mount path issues") |
625 | 675 |
|
626 |
| - # Prepare authentication parameters for system prompt |
627 |
| - if args.use_session_cookie: |
628 |
| - system_prompt = system_prompt_template.format( |
629 |
| - current_utc_time=current_utc_time, |
630 |
| - mcp_registry_url=args.mcp_registry_url, |
631 |
| - auth_token='', # Not used for session cookie auth |
632 |
| - user_pool_id=args.user_pool_id or '', |
633 |
| - client_id=args.client_id or '', |
634 |
| - region=args.region or 'us-east-1', |
635 |
| - auth_method=auth_method, |
636 |
| - session_cookie=session_cookie |
| 676 | + try: |
| 677 | + # Initialize MCP client with the server configuration and authentication headers |
| 678 | + client = MultiServerMCPClient( |
| 679 | + { |
| 680 | + "mcp_registry": { |
| 681 | + "url": server_url, |
| 682 | + "transport": "sse", |
| 683 | + "headers": auth_headers |
| 684 | + } |
| 685 | + } |
637 | 686 | )
|
638 |
| - else: |
639 |
| - system_prompt = system_prompt_template.format( |
640 |
| - current_utc_time=current_utc_time, |
641 |
| - mcp_registry_url=args.mcp_registry_url, |
642 |
| - auth_token=access_token, |
643 |
| - user_pool_id=args.user_pool_id, |
644 |
| - client_id=args.client_id, |
645 |
| - region=args.region, |
646 |
| - auth_method=auth_method, |
647 |
| - session_cookie='' # Not used for M2M auth |
| 687 | + logger.info("Connected to MCP server successfully with authentication, server_url: " + server_url) |
| 688 | + |
| 689 | + # Get available tools from MCP and display them |
| 690 | + mcp_tools = await client.get_tools() |
| 691 | + logger.info(f"Available MCP tools: {[tool.name for tool in mcp_tools]}") |
| 692 | + |
| 693 | + # Add the calculator and invoke_mcp_tool to the tools array |
| 694 | + # The invoke_mcp_tool function already supports authentication parameters |
| 695 | + all_tools = [calculator, invoke_mcp_tool] + mcp_tools |
| 696 | + logger.info(f"All available tools: {[tool.name if hasattr(tool, 'name') else tool.__name__ for tool in all_tools]}") |
| 697 | + |
| 698 | + # Create the agent with the model and all tools |
| 699 | + agent = create_react_agent( |
| 700 | + model, |
| 701 | + all_tools |
648 | 702 | )
|
649 |
| - |
650 |
| - # Format the message with system message first |
651 |
| - formatted_messages = [ |
652 |
| - {"role": "system", "content": system_prompt}, |
653 |
| - {"role": "user", "content": args.message} |
654 |
| - ] |
655 |
| - |
656 |
| - logger.info("\nInvoking agent...\n" + "-"*40) |
657 |
| - |
658 |
| - # Invoke the agent with the formatted messages |
659 |
| - response = await agent.ainvoke({"messages": formatted_messages}) |
660 |
| - |
661 |
| - logger.info("\nResponse:" + "\n" + "-"*40) |
662 |
| - #print(response) |
663 |
| - print_agent_response(response) |
664 |
| - |
665 |
| - # Process and display the response |
666 |
| - if response and "messages" in response and response["messages"]: |
667 |
| - # Get the last message from the response |
668 |
| - last_message = response["messages"][-1] |
669 | 703 |
|
670 |
| - if isinstance(last_message, dict) and "content" in last_message: |
671 |
| - # Display the content of the response |
672 |
| - print(last_message["content"]) |
| 704 | + # Load and format the system prompt with the current time and MCP registry URL |
| 705 | + system_prompt_template = load_system_prompt() |
| 706 | + |
| 707 | + # Prepare authentication parameters for system prompt |
| 708 | + if args.use_session_cookie: |
| 709 | + system_prompt = system_prompt_template.format( |
| 710 | + current_utc_time=current_utc_time, |
| 711 | + mcp_registry_url=args.mcp_registry_url, |
| 712 | + auth_token='', # Not used for session cookie auth |
| 713 | + user_pool_id=args.user_pool_id or '', |
| 714 | + client_id=args.client_id or '', |
| 715 | + region=args.region or 'us-east-1', |
| 716 | + auth_method=auth_method, |
| 717 | + session_cookie=session_cookie |
| 718 | + ) |
673 | 719 | else:
|
674 |
| - print(str(last_message.content)) |
675 |
| - else: |
676 |
| - print("No valid response received") |
| 720 | + system_prompt = system_prompt_template.format( |
| 721 | + current_utc_time=current_utc_time, |
| 722 | + mcp_registry_url=args.mcp_registry_url, |
| 723 | + auth_token=access_token, |
| 724 | + user_pool_id=args.user_pool_id, |
| 725 | + client_id=args.client_id, |
| 726 | + region=args.region, |
| 727 | + auth_method=auth_method, |
| 728 | + session_cookie='' # Not used for M2M auth |
| 729 | + ) |
| 730 | + |
| 731 | + # Format the message with system message first |
| 732 | + formatted_messages = [ |
| 733 | + {"role": "system", "content": system_prompt}, |
| 734 | + {"role": "user", "content": args.message} |
| 735 | + ] |
| 736 | + |
| 737 | + logger.info("\nInvoking agent...\n" + "-"*40) |
| 738 | + |
| 739 | + # Invoke the agent with the formatted messages |
| 740 | + response = await agent.ainvoke({"messages": formatted_messages}) |
| 741 | + |
| 742 | + logger.info("\nResponse:" + "\n" + "-"*40) |
| 743 | + #print(response) |
| 744 | + print_agent_response(response) |
| 745 | + |
| 746 | + # Process and display the response |
| 747 | + if response and "messages" in response and response["messages"]: |
| 748 | + # Get the last message from the response |
| 749 | + last_message = response["messages"][-1] |
| 750 | + |
| 751 | + if isinstance(last_message, dict) and "content" in last_message: |
| 752 | + # Display the content of the response |
| 753 | + print(last_message["content"]) |
| 754 | + else: |
| 755 | + print(str(last_message.content)) |
| 756 | + else: |
| 757 | + print("No valid response received") |
| 758 | + |
| 759 | + finally: |
| 760 | + # Restore original httpx behavior |
| 761 | + httpx.AsyncClient.request = original_request |
| 762 | + logger.info("Restored original httpx behavior") |
677 | 763 |
|
678 | 764 | except Exception as e:
|
| 765 | + # Restore original httpx behavior in case of error |
| 766 | + try: |
| 767 | + httpx.AsyncClient.request = original_request |
| 768 | + except NameError: |
| 769 | + pass # original_request might not be defined if error occurred before monkey patch |
679 | 770 | print(f"Error: {str(e)}")
|
680 | 771 | import traceback
|
681 | 772 | print(traceback.format_exc())
|
|
0 commit comments