Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Makefile
htmlcov/
.ruff_cache/
new/
.DS_Store

# Generated endpoints file (created by editable install)
src/golf/_endpoints.py
Expand Down
22 changes: 11 additions & 11 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,39 +19,39 @@ def render_endpoints(require_env_vars: bool = True):

# Get environment variables
platform_url = os.environ.get("GOLF_PLATFORM_API_URL")
otel_url = os.environ.get("GOLF_OTEL_ENDPOINT")
otel_url = os.environ.get("GOLF_OTEL_ENDPOINT")

# For production builds, require environment variables
# For development/editable installs, use fallback values
if require_env_vars and (not platform_url or not otel_url):
raise SystemExit(
"Missing required environment variables for URL injection:\n"
" GOLF_PLATFORM_API_URL\n"
" GOLF_PLATFORM_API_URL\n"
" GOLF_OTEL_ENDPOINT\n"
"Set these before building the package."
)

# Use environment variables if available, otherwise fallback to development URLs
values = {
"PLATFORM_API_URL": platform_url or "http://localhost:8000/api/resources",
"OTEL_ENDPOINT": otel_url or "http://localhost:4318/v1/traces",
}

try:
rendered = tpl_path.read_text(encoding="utf-8").format(**values)
except KeyError as e:
raise SystemExit(f"Missing template key: {e}") from e

return rendered


class build_py(_build_py):
"""Custom build_py that renders endpoints into the build_lib (wheel contents)."""

def run(self):
# First run the normal build
super().run()

# Then render endpoints into the build_lib
# Skip env var requirement if in CI environment or if this looks like a test install
is_ci = os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS")
Expand All @@ -64,11 +64,11 @@ def run(self):

class develop(_develop):
"""Custom develop for editable installs - generates endpoints file in source tree."""

def run(self):
# Run normal develop command first
super().run()

# Generate a working copy file for editable installs (use fallback URLs if env vars missing)
rendered = render_endpoints(require_env_vars=False)
# For editable installs, write into the source tree so imports work
Expand All @@ -84,4 +84,4 @@ def run(self):
"build_py": build_py,
"develop": develop,
}
)
)
38 changes: 26 additions & 12 deletions src/golf/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
create_simple_jwt_provider,
create_dev_token_provider,
)
from .registry import (
BaseProviderPlugin,
AuthProviderFactory,
get_provider_registry,
register_provider_factory,
register_provider_plugin,
)

# Re-export for backward compatibility
from .api_key import configure_api_key, get_api_key_config, is_api_key_configured
Expand Down Expand Up @@ -48,6 +55,12 @@
"create_auth_provider",
"create_simple_jwt_provider",
"create_dev_token_provider",
# Provider registry and plugins
"BaseProviderPlugin",
"AuthProviderFactory",
"get_provider_registry",
"register_provider_factory",
"register_provider_plugin",
# API key functions (backward compatibility)
"configure_api_key",
"get_api_key_config",
Expand All @@ -61,18 +74,18 @@
]

# Global storage for auth configuration
_auth_config: tuple[AuthConfig, list[str] | None] | None = None
_auth_config: AuthConfig | None = None


def configure_auth(config: AuthConfig, required_scopes: list[str] | None = None) -> None:
def configure_auth(config: AuthConfig) -> None:
"""Configure authentication for the Golf server.

This function should be called in auth.py to set up authentication
using FastMCP's modern auth providers.

Args:
config: Authentication configuration (JWT, OAuth, Static, or Remote)
required_scopes: Optional list of scopes required for all requests
The required_scopes should be specified in the config itself.

Examples:
# JWT authentication with Auth0
Expand All @@ -97,7 +110,8 @@ def configure_auth(config: AuthConfig, required_scopes: list[str] | None = None)
"client_id": "dev-client",
"scopes": ["read", "write"],
}
}
},
required_scopes=["read"],
)
)

Expand All @@ -109,11 +123,12 @@ def configure_auth(config: AuthConfig, required_scopes: list[str] | None = None)
base_url="https://your-server.example.com",
valid_scopes=["read", "write", "admin"],
default_scopes=["read"],
required_scopes=["read"],
)
)
"""
global _auth_config
_auth_config = (config, required_scopes)
_auth_config = config


def configure_jwt_auth(
Expand Down Expand Up @@ -144,7 +159,7 @@ def configure_jwt_auth(
required_scopes=required_scopes or [],
**env_vars,
)
configure_auth(config, required_scopes)
configure_auth(config)


def configure_dev_auth(
Expand Down Expand Up @@ -173,14 +188,14 @@ def configure_dev_auth(
tokens=tokens,
required_scopes=required_scopes or [],
)
configure_auth(config, required_scopes)
configure_auth(config)


def get_auth_config() -> tuple[AuthConfig, list[str] | None] | None:
def get_auth_config() -> AuthConfig | None:
"""Get the current auth configuration.

Returns:
Tuple of (auth_config, required_scopes) if configured, None otherwise
AuthConfig if configured, None otherwise
"""
return _auth_config

Expand All @@ -204,9 +219,8 @@ def create_auth_provider_from_config() -> object | None:
Returns:
FastMCP AuthProvider instance or None if not configured
"""
config_tuple = get_auth_config()
if not config_tuple:
config = get_auth_config()
if not config:
return None

config, _ = config_tuple
return create_auth_provider(config)
134 changes: 112 additions & 22 deletions src/golf/auth/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
if TYPE_CHECKING:
from fastmcp.server.auth.auth import AuthProvider
from fastmcp.server.auth import JWTVerifier, StaticTokenVerifier
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
from mcp.server.auth.settings import RevocationOptions

from .providers import (
AuthConfig,
Expand All @@ -18,11 +18,19 @@
OAuthServerConfig,
RemoteAuthConfig,
)
from .registry import (
get_provider_registry,
create_auth_provider_from_registry,
)


def create_auth_provider(config: AuthConfig) -> "AuthProvider":
"""Create a FastMCP AuthProvider from Golf auth configuration.

This function uses the provider registry system to allow extensibility.
Built-in providers are automatically registered, and custom providers
can be added via the registry system.

Args:
config: Golf authentication configuration

Expand All @@ -32,17 +40,23 @@ def create_auth_provider(config: AuthConfig) -> "AuthProvider":
Raises:
ValueError: If configuration is invalid
ImportError: If required dependencies are missing
KeyError: If provider type is not registered
"""
if config.provider_type == "jwt":
return _create_jwt_provider(config)
elif config.provider_type == "static":
return _create_static_provider(config)
elif config.provider_type == "oauth_server":
return _create_oauth_server_provider(config)
elif config.provider_type == "remote":
return _create_remote_provider(config)
else:
raise ValueError(f"Unknown provider type: {config.provider_type}")
try:
return create_auth_provider_from_registry(config)
except KeyError:
# Fall back to legacy dispatch for backward compatibility
# This ensures existing code continues to work during transition
if config.provider_type == "jwt":
return _create_jwt_provider(config)
elif config.provider_type == "static":
return _create_static_provider(config)
elif config.provider_type == "oauth_server":
return _create_oauth_server_provider(config)
elif config.provider_type == "remote":
return _create_remote_provider(config)
else:
raise ValueError(f"Unknown provider type: {config.provider_type}") from None


def _create_jwt_provider(config: JWTAuthConfig) -> "JWTVerifier":
Expand Down Expand Up @@ -123,21 +137,61 @@ def _create_oauth_server_provider(config: OAuthServerConfig) -> "AuthProvider":
"OAuthProvider not available in this FastMCP version. Please upgrade to FastMCP 2.11.0 or later."
) from e

# Resolve runtime values from environment variables
# Resolve runtime values from environment variables with validation
base_url = config.base_url
if config.base_url_env_var:
env_value = os.environ.get(config.base_url_env_var)
if env_value:
base_url = env_value
# Apply the same validation as the config field to env var value
try:
from urllib.parse import urlparse

env_value = env_value.strip()
parsed = urlparse(env_value)

if not parsed.scheme or not parsed.netloc:
raise ValueError(
f"Invalid base URL from environment variable {config.base_url_env_var}: '{env_value}'"
)

if parsed.scheme not in ("http", "https"):
raise ValueError(f"Base URL from environment must use http/https: '{env_value}'")

# Production HTTPS check
is_production = (
os.environ.get("GOLF_ENV", "").lower() in ("prod", "production")
or os.environ.get("NODE_ENV", "").lower() == "production"
or os.environ.get("ENVIRONMENT", "").lower() in ("prod", "production")
)

if is_production and parsed.scheme == "http":
raise ValueError(f"Base URL must use HTTPS in production: '{env_value}'")

base_url = env_value

except Exception as e:
raise ValueError(f"Invalid base URL from environment variable {config.base_url_env_var}: {e}") from e

# Additional security validations before creating provider
from urllib.parse import urlparse

# Validate final base_url
parsed_base = urlparse(base_url)
if not parsed_base.scheme or not parsed_base.netloc:
raise ValueError(f"Invalid base URL: '{base_url}'")

# Security check: prevent localhost in production
is_production = (
os.environ.get("GOLF_ENV", "").lower() in ("prod", "production")
or os.environ.get("NODE_ENV", "").lower() == "production"
or os.environ.get("ENVIRONMENT", "").lower() in ("prod", "production")
)

if is_production and parsed_base.hostname in ("localhost", "127.0.0.1", "0.0.0.0"):
raise ValueError(f"Cannot use localhost/loopback addresses in production: '{base_url}'")

# Create client registration options
# Client registration options - always disabled for security
client_reg_options = None
if config.allow_client_registration:
client_reg_options = ClientRegistrationOptions(
enabled=True,
valid_scopes=config.valid_scopes,
default_scopes=config.default_scopes,
)

# Create revocation options
revocation_options = None
Expand All @@ -163,6 +217,20 @@ def _create_remote_provider(config: RemoteAuthConfig) -> "AuthProvider":
"RemoteAuthProvider not available in this FastMCP version. Please upgrade to FastMCP 2.11.0 or later."
) from e

# Resolve runtime values from environment variables
authorization_servers = config.authorization_servers
if config.authorization_servers_env_var:
env_value = os.environ.get(config.authorization_servers_env_var)
if env_value:
# Split comma-separated values and strip whitespace
authorization_servers = [s.strip() for s in env_value.split(",")]

resource_server_url = config.resource_server_url
if config.resource_server_url_env_var:
env_value = os.environ.get(config.resource_server_url_env_var)
if env_value:
resource_server_url = env_value

# Create the underlying token verifier
token_verifier = create_auth_provider(config.token_verifier_config)

Expand All @@ -172,8 +240,8 @@ def _create_remote_provider(config: RemoteAuthConfig) -> "AuthProvider":

return RemoteAuthProvider(
token_verifier=token_verifier,
authorization_servers=config.authorization_servers,
resource_server_url=config.resource_server_url,
authorization_servers=authorization_servers,
resource_server_url=resource_server_url,
)


Expand Down Expand Up @@ -241,3 +309,25 @@ def create_dev_token_provider(
required_scopes=required_scopes or [],
)
return _create_static_provider(config)


def register_builtin_providers() -> None:
"""Register built-in authentication providers in the registry.

This function registers the standard Golf authentication providers:
- jwt: JWT token verification
- static: Static token verification (development)
- oauth_server: Full OAuth authorization server
- remote: Remote authorization server integration
"""
registry = get_provider_registry()

# Register built-in provider factories
registry.register_factory("jwt", _create_jwt_provider)
registry.register_factory("static", _create_static_provider)
registry.register_factory("oauth_server", _create_oauth_server_provider)
registry.register_factory("remote", _create_remote_provider)


# Register built-in providers when module is imported
register_builtin_providers()
Loading