|
6 | 6 | import logging |
7 | 7 | from pathlib import Path |
8 | 8 | from typing import Optional, Dict, Union |
9 | | -from fastmcp import FastMCP, Context |
10 | 9 | from pydantic import Field |
| 10 | +from fastmcp import FastMCP, Context |
| 11 | +from fastmcp.server.auth.oidc_proxy import OIDCProxy |
| 12 | +from fastmcp.server.auth import OAuthProxy, RemoteAuthProvider |
| 13 | +from fastmcp.server.auth.providers.jwt import JWTVerifier, StaticTokenVerifier |
| 14 | +from fastmcp.server.middleware.logging import LoggingMiddleware |
| 15 | +from fastmcp.server.middleware.timing import TimingMiddleware |
| 16 | +from fastmcp.server.middleware.rate_limiting import RateLimitingMiddleware |
| 17 | +from fastmcp.server.middleware.error_handling import ErrorHandlingMiddleware |
11 | 18 | from vector_mcp.retriever.retriever import RAGRetriever |
12 | 19 | from vector_mcp.retriever.pgvector_retriever import PGVectorRetriever |
13 | 20 | from vector_mcp.retriever.qdrant_retriever import QdrantRetriever |
@@ -558,20 +565,225 @@ def vector_mcp(): |
558 | 565 | default=8000, |
559 | 566 | help="Port number for HTTP transport (default: 8000)", |
560 | 567 | ) |
| 568 | + parser.add_argument( |
| 569 | + "--auth-type", |
| 570 | + default="none", |
| 571 | + choices=["none", "static", "jwt", "oauth-proxy", "oidc-proxy", "remote-oauth"], |
| 572 | + help="Authentication type for MCP server: 'none' (disabled), 'static' (internal), 'jwt' (external token verification), 'oauth-proxy', 'oidc-proxy', 'remote-oauth' (external) (default: none)", |
| 573 | + ) |
| 574 | + # JWT/Token params |
| 575 | + parser.add_argument( |
| 576 | + "--token-jwks-uri", default=None, help="JWKS URI for JWT verification" |
| 577 | + ) |
| 578 | + parser.add_argument( |
| 579 | + "--token-issuer", default=None, help="Issuer for JWT verification" |
| 580 | + ) |
| 581 | + parser.add_argument( |
| 582 | + "--token-audience", default=None, help="Audience for JWT verification" |
| 583 | + ) |
| 584 | + # OAuth Proxy params |
| 585 | + parser.add_argument( |
| 586 | + "--oauth-upstream-auth-endpoint", |
| 587 | + default=None, |
| 588 | + help="Upstream authorization endpoint for OAuth Proxy", |
| 589 | + ) |
| 590 | + parser.add_argument( |
| 591 | + "--oauth-upstream-token-endpoint", |
| 592 | + default=None, |
| 593 | + help="Upstream token endpoint for OAuth Proxy", |
| 594 | + ) |
| 595 | + parser.add_argument( |
| 596 | + "--oauth-upstream-client-id", |
| 597 | + default=None, |
| 598 | + help="Upstream client ID for OAuth Proxy", |
| 599 | + ) |
| 600 | + parser.add_argument( |
| 601 | + "--oauth-upstream-client-secret", |
| 602 | + default=None, |
| 603 | + help="Upstream client secret for OAuth Proxy", |
| 604 | + ) |
| 605 | + parser.add_argument( |
| 606 | + "--oauth-base-url", default=None, help="Base URL for OAuth Proxy" |
| 607 | + ) |
| 608 | + # OIDC Proxy params |
| 609 | + parser.add_argument( |
| 610 | + "--oidc-config-url", default=None, help="OIDC configuration URL" |
| 611 | + ) |
| 612 | + parser.add_argument("--oidc-client-id", default=None, help="OIDC client ID") |
| 613 | + parser.add_argument("--oidc-client-secret", default=None, help="OIDC client secret") |
| 614 | + parser.add_argument("--oidc-base-url", default=None, help="Base URL for OIDC Proxy") |
| 615 | + # Remote OAuth params |
| 616 | + parser.add_argument( |
| 617 | + "--remote-auth-servers", |
| 618 | + default=None, |
| 619 | + help="Comma-separated list of authorization servers for Remote OAuth", |
| 620 | + ) |
| 621 | + parser.add_argument( |
| 622 | + "--remote-base-url", default=None, help="Base URL for Remote OAuth" |
| 623 | + ) |
| 624 | + # Common |
| 625 | + parser.add_argument( |
| 626 | + "--allowed-client-redirect-uris", |
| 627 | + default=None, |
| 628 | + help="Comma-separated list of allowed client redirect URIs", |
| 629 | + ) |
| 630 | + # Eunomia params |
| 631 | + parser.add_argument( |
| 632 | + "--eunomia-type", |
| 633 | + default="none", |
| 634 | + choices=["none", "embedded", "remote"], |
| 635 | + help="Eunomia authorization type: 'none' (disabled), 'embedded' (built-in), 'remote' (external) (default: none)", |
| 636 | + ) |
| 637 | + parser.add_argument( |
| 638 | + "--eunomia-policy-file", |
| 639 | + default="mcp_policies.json", |
| 640 | + help="Policy file for embedded Eunomia (default: mcp_policies.json)", |
| 641 | + ) |
| 642 | + parser.add_argument( |
| 643 | + "--eunomia-remote-url", default=None, help="URL for remote Eunomia server" |
| 644 | + ) |
561 | 645 |
|
562 | 646 | args = parser.parse_args() |
563 | 647 |
|
564 | 648 | if args.port < 0 or args.port > 65535: |
565 | 649 | print(f"Error: Port {args.port} is out of valid range (0-65535).") |
566 | 650 | sys.exit(1) |
567 | 651 |
|
| 652 | + # Set auth based on type |
| 653 | + auth = None |
| 654 | + allowed_uris = ( |
| 655 | + args.allowed_client_redirect_uris.split(",") |
| 656 | + if args.allowed_client_redirect_uris |
| 657 | + else None |
| 658 | + ) |
| 659 | + |
| 660 | + if args.auth_type == "none": |
| 661 | + auth = None |
| 662 | + elif args.auth_type == "static": |
| 663 | + # Internal static tokens (hardcoded example) |
| 664 | + auth = StaticTokenVerifier( |
| 665 | + tokens={ |
| 666 | + "test-token": {"client_id": "test-user", "scopes": ["read", "write"]}, |
| 667 | + "admin-token": {"client_id": "admin", "scopes": ["admin"]}, |
| 668 | + } |
| 669 | + ) |
| 670 | + elif args.auth_type == "jwt": |
| 671 | + if not (args.token_jwks_uri and args.token_issuer and args.token_audience): |
| 672 | + print( |
| 673 | + "Error: jwt requires --token-jwks-uri, --token-issuer, --token-audience" |
| 674 | + ) |
| 675 | + sys.exit(1) |
| 676 | + auth = JWTVerifier( |
| 677 | + jwks_uri=args.token_jwks_uri, |
| 678 | + issuer=args.token_issuer, |
| 679 | + audience=args.token_audience, |
| 680 | + ) |
| 681 | + elif args.auth_type == "oauth-proxy": |
| 682 | + if not ( |
| 683 | + args.oauth_upstream_auth_endpoint |
| 684 | + and args.oauth_upstream_token_endpoint |
| 685 | + and args.oauth_upstream_client_id |
| 686 | + and args.oauth_upstream_client_secret |
| 687 | + and args.oauth_base_url |
| 688 | + and args.token_jwks_uri |
| 689 | + and args.token_issuer |
| 690 | + and args.token_audience |
| 691 | + ): |
| 692 | + print( |
| 693 | + "Error: oauth-proxy requires --oauth-upstream-auth-endpoint, --oauth-upstream-token-endpoint, --oauth-upstream-client-id, --oauth-upstream-client-secret, --oauth-base-url, --token-jwks-uri, --token-issuer, --token-audience" |
| 694 | + ) |
| 695 | + sys.exit(1) |
| 696 | + token_verifier = JWTVerifier( |
| 697 | + jwks_uri=args.token_jwks_uri, |
| 698 | + issuer=args.token_issuer, |
| 699 | + audience=args.token_audience, |
| 700 | + ) |
| 701 | + auth = OAuthProxy( |
| 702 | + upstream_authorization_endpoint=args.oauth_upstream_auth_endpoint, |
| 703 | + upstream_token_endpoint=args.oauth_upstream_token_endpoint, |
| 704 | + upstream_client_id=args.oauth_upstream_client_id, |
| 705 | + upstream_client_secret=args.oauth_upstream_client_secret, |
| 706 | + token_verifier=token_verifier, |
| 707 | + base_url=args.oauth_base_url, |
| 708 | + allowed_client_redirect_uris=allowed_uris, |
| 709 | + ) |
| 710 | + elif args.auth_type == "oidc-proxy": |
| 711 | + if not ( |
| 712 | + args.oidc_config_url |
| 713 | + and args.oidc_client_id |
| 714 | + and args.oidc_client_secret |
| 715 | + and args.oidc_base_url |
| 716 | + ): |
| 717 | + print( |
| 718 | + "Error: oidc-proxy requires --oidc-config-url, --oidc-client-id, --oidc-client-secret, --oidc-base-url" |
| 719 | + ) |
| 720 | + sys.exit(1) |
| 721 | + auth = OIDCProxy( |
| 722 | + config_url=args.oidc_config_url, |
| 723 | + client_id=args.oidc_client_id, |
| 724 | + client_secret=args.oidc_client_secret, |
| 725 | + base_url=args.oidc_base_url, |
| 726 | + allowed_client_redirect_uris=allowed_uris, |
| 727 | + ) |
| 728 | + elif args.auth_type == "remote-oauth": |
| 729 | + if not ( |
| 730 | + args.remote_auth_servers |
| 731 | + and args.remote_base_url |
| 732 | + and args.token_jwks_uri |
| 733 | + and args.token_issuer |
| 734 | + and args.token_audience |
| 735 | + ): |
| 736 | + print( |
| 737 | + "Error: remote-oauth requires --remote-auth-servers, --remote-base-url, --token-jwks-uri, --token-issuer, --token-audience" |
| 738 | + ) |
| 739 | + sys.exit(1) |
| 740 | + auth_servers = [url.strip() for url in args.remote_auth_servers.split(",")] |
| 741 | + token_verifier = JWTVerifier( |
| 742 | + jwks_uri=args.token_jwks_uri, |
| 743 | + issuer=args.token_issuer, |
| 744 | + audience=args.token_audience, |
| 745 | + ) |
| 746 | + auth = RemoteAuthProvider( |
| 747 | + token_verifier=token_verifier, |
| 748 | + authorization_servers=auth_servers, |
| 749 | + base_url=args.remote_base_url, |
| 750 | + ) |
| 751 | + mcp.auth = auth |
| 752 | + if args.eunomia_type != "none": |
| 753 | + from eunomia_mcp import create_eunomia_middleware |
| 754 | + |
| 755 | + if args.eunomia_type == "embedded": |
| 756 | + if not args.eunomia_policy_file: |
| 757 | + print("Error: embedded Eunomia requires --eunomia-policy-file") |
| 758 | + sys.exit(1) |
| 759 | + middleware = create_eunomia_middleware(policy_file=args.eunomia_policy_file) |
| 760 | + mcp.add_middleware(middleware) |
| 761 | + elif args.eunomia_type == "remote": |
| 762 | + if not args.eunomia_remote_url: |
| 763 | + print("Error: remote Eunomia requires --eunomia-remote-url") |
| 764 | + sys.exit(1) |
| 765 | + middleware = create_eunomia_middleware( |
| 766 | + use_remote_eunomia=args.eunomia_remote_url |
| 767 | + ) |
| 768 | + mcp.add_middleware(middleware) |
| 769 | + |
| 770 | + mcp.add_middleware( |
| 771 | + ErrorHandlingMiddleware(include_traceback=True, transform_errors=True) |
| 772 | + ) |
| 773 | + mcp.add_middleware( |
| 774 | + RateLimitingMiddleware(max_requests_per_second=10.0, burst_capacity=20) |
| 775 | + ) |
| 776 | + mcp.add_middleware(TimingMiddleware()) |
| 777 | + mcp.add_middleware(LoggingMiddleware()) |
| 778 | + |
568 | 779 | if args.transport == "stdio": |
569 | 780 | mcp.run(transport="stdio") |
570 | 781 | elif args.transport == "http": |
571 | 782 | mcp.run(transport="http", host=args.host, port=args.port) |
572 | 783 | elif args.transport == "sse": |
573 | 784 | mcp.run(transport="sse", host=args.host, port=args.port) |
574 | 785 | else: |
| 786 | + logger = logging.getLogger("Vector") |
575 | 787 | logger.error("Transport not supported") |
576 | 788 | sys.exit(1) |
577 | 789 |
|
|
0 commit comments