Skip to content

Commit c56ff46

Browse files
committed
test: add unit tests for SEP-991 CIMD support
Add comprehensive unit tests for the Client ID Metadata Document (CIMD) functionality including: - URL validation tests for is_valid_client_metadata_url - Tests for should_use_client_metadata_url logic - Tests for create_client_info_from_metadata_url - OAuthClientProvider initialization tests with client_metadata_url - Auth flow tests verifying CIMD is used when server supports it - Auth flow tests verifying fallback to DCR when CIMD not supported
1 parent 66c783e commit c56ff46

File tree

1 file changed

+294
-0
lines changed

1 file changed

+294
-0
lines changed

tests/client/test_auth.py

Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717
from mcp.client.auth.utils import (
1818
build_oauth_authorization_server_metadata_discovery_urls,
1919
build_protected_resource_metadata_discovery_urls,
20+
create_client_info_from_metadata_url,
2021
create_oauth_metadata_request,
2122
extract_field_from_www_auth,
2223
extract_resource_metadata_from_www_auth,
2324
extract_scope_from_www_auth,
2425
get_client_metadata_scopes,
2526
handle_registration_response,
27+
is_valid_client_metadata_url,
28+
should_use_client_metadata_url,
2629
)
2730
from mcp.shared.auth import (
2831
OAuthClientInformationFull,
@@ -1783,3 +1786,294 @@ def test_extract_field_from_www_auth_invalid_cases(
17831786

17841787
result = extract_field_from_www_auth(init_response, field_name)
17851788
assert result is None, f"Should return None for {description}"
1789+
1790+
1791+
class TestCIMD:
1792+
"""Test SEP-991 Client ID Metadata Document (CIMD) support."""
1793+
1794+
@pytest.mark.parametrize(
1795+
"url,expected",
1796+
[
1797+
# Valid CIMD URLs
1798+
("https://example.com/client", True),
1799+
("https://example.com/client-metadata.json", True),
1800+
("https://example.com/path/to/client", True),
1801+
("https://example.com:8443/client", True),
1802+
# Invalid URLs - HTTP (not HTTPS)
1803+
("http://example.com/client", False),
1804+
# Invalid URLs - root path
1805+
("https://example.com", False),
1806+
("https://example.com/", False),
1807+
# Invalid URLs - None or empty
1808+
(None, False),
1809+
("", False),
1810+
],
1811+
)
1812+
def test_is_valid_client_metadata_url(self, url: str | None, expected: bool):
1813+
"""Test CIMD URL validation."""
1814+
assert is_valid_client_metadata_url(url) == expected
1815+
1816+
def test_should_use_client_metadata_url_when_server_supports(self):
1817+
"""Test that CIMD is used when server supports it and URL is provided."""
1818+
oauth_metadata = OAuthMetadata(
1819+
issuer=AnyHttpUrl("https://auth.example.com"),
1820+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
1821+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
1822+
client_id_metadata_document_supported=True,
1823+
)
1824+
assert should_use_client_metadata_url(oauth_metadata, "https://example.com/client") is True
1825+
1826+
def test_should_not_use_client_metadata_url_when_server_does_not_support(self):
1827+
"""Test that CIMD is not used when server doesn't support it."""
1828+
oauth_metadata = OAuthMetadata(
1829+
issuer=AnyHttpUrl("https://auth.example.com"),
1830+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
1831+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
1832+
client_id_metadata_document_supported=False,
1833+
)
1834+
assert should_use_client_metadata_url(oauth_metadata, "https://example.com/client") is False
1835+
1836+
def test_should_not_use_client_metadata_url_when_not_provided(self):
1837+
"""Test that CIMD is not used when no URL is provided."""
1838+
oauth_metadata = OAuthMetadata(
1839+
issuer=AnyHttpUrl("https://auth.example.com"),
1840+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
1841+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
1842+
client_id_metadata_document_supported=True,
1843+
)
1844+
assert should_use_client_metadata_url(oauth_metadata, None) is False
1845+
1846+
def test_should_not_use_client_metadata_url_when_no_metadata(self):
1847+
"""Test that CIMD is not used when OAuth metadata is None."""
1848+
assert should_use_client_metadata_url(None, "https://example.com/client") is False
1849+
1850+
def test_create_client_info_from_metadata_url(self):
1851+
"""Test creating client info from CIMD URL."""
1852+
client_info = create_client_info_from_metadata_url(
1853+
"https://example.com/client",
1854+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
1855+
)
1856+
assert client_info.client_id == "https://example.com/client"
1857+
assert client_info.token_endpoint_auth_method == "none"
1858+
assert client_info.redirect_uris == [AnyUrl("http://localhost:3030/callback")]
1859+
assert client_info.client_secret is None
1860+
1861+
def test_oauth_provider_with_valid_client_metadata_url(
1862+
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
1863+
):
1864+
"""Test OAuthClientProvider initialization with valid client_metadata_url."""
1865+
1866+
async def redirect_handler(url: str) -> None:
1867+
pass # pragma: no cover
1868+
1869+
async def callback_handler() -> tuple[str, str | None]:
1870+
return "test_auth_code", "test_state" # pragma: no cover
1871+
1872+
provider = OAuthClientProvider(
1873+
server_url="https://api.example.com/v1/mcp",
1874+
client_metadata=client_metadata,
1875+
storage=mock_storage,
1876+
redirect_handler=redirect_handler,
1877+
callback_handler=callback_handler,
1878+
client_metadata_url="https://example.com/client",
1879+
)
1880+
assert provider.context.client_metadata_url == "https://example.com/client"
1881+
1882+
def test_oauth_provider_with_invalid_client_metadata_url_raises_error(
1883+
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
1884+
):
1885+
"""Test OAuthClientProvider raises error for invalid client_metadata_url."""
1886+
1887+
async def redirect_handler(url: str) -> None:
1888+
pass # pragma: no cover
1889+
1890+
async def callback_handler() -> tuple[str, str | None]:
1891+
return "test_auth_code", "test_state" # pragma: no cover
1892+
1893+
with pytest.raises(ValueError) as exc_info:
1894+
OAuthClientProvider(
1895+
server_url="https://api.example.com/v1/mcp",
1896+
client_metadata=client_metadata,
1897+
storage=mock_storage,
1898+
redirect_handler=redirect_handler,
1899+
callback_handler=callback_handler,
1900+
client_metadata_url="http://example.com/client", # HTTP instead of HTTPS
1901+
)
1902+
assert "HTTPS URL with a non-root pathname" in str(exc_info.value)
1903+
1904+
@pytest.mark.anyio
1905+
async def test_auth_flow_uses_cimd_when_server_supports(
1906+
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
1907+
):
1908+
"""Test that auth flow uses CIMD URL as client_id when server supports it."""
1909+
1910+
async def redirect_handler(url: str) -> None:
1911+
pass # pragma: no cover
1912+
1913+
async def callback_handler() -> tuple[str, str | None]:
1914+
return "test_auth_code", "test_state" # pragma: no cover
1915+
1916+
provider = OAuthClientProvider(
1917+
server_url="https://api.example.com/v1/mcp",
1918+
client_metadata=client_metadata,
1919+
storage=mock_storage,
1920+
redirect_handler=redirect_handler,
1921+
callback_handler=callback_handler,
1922+
client_metadata_url="https://example.com/client",
1923+
)
1924+
1925+
provider.context.current_tokens = None
1926+
provider.context.token_expiry_time = None
1927+
provider._initialized = True
1928+
1929+
test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
1930+
auth_flow = provider.async_auth_flow(test_request)
1931+
1932+
# First request
1933+
request = await auth_flow.__anext__()
1934+
assert "Authorization" not in request.headers
1935+
1936+
# Send 401 response
1937+
response = httpx.Response(401, headers={}, request=test_request)
1938+
1939+
# PRM discovery
1940+
prm_request = await auth_flow.asend(response)
1941+
prm_response = httpx.Response(
1942+
200,
1943+
content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}',
1944+
request=prm_request,
1945+
)
1946+
1947+
# OAuth metadata discovery
1948+
oauth_request = await auth_flow.asend(prm_response)
1949+
oauth_response = httpx.Response(
1950+
200,
1951+
content=(
1952+
b'{"issuer": "https://auth.example.com", '
1953+
b'"authorization_endpoint": "https://auth.example.com/authorize", '
1954+
b'"token_endpoint": "https://auth.example.com/token", '
1955+
b'"client_id_metadata_document_supported": true}'
1956+
),
1957+
request=oauth_request,
1958+
)
1959+
1960+
# Mock authorization
1961+
provider._perform_authorization_code_grant = mock.AsyncMock(
1962+
return_value=("test_auth_code", "test_code_verifier")
1963+
)
1964+
1965+
# Should skip DCR and go directly to token exchange
1966+
token_request = await auth_flow.asend(oauth_response)
1967+
assert token_request.method == "POST"
1968+
assert str(token_request.url) == "https://auth.example.com/token"
1969+
1970+
# Verify client_id is the CIMD URL
1971+
content = token_request.content.decode()
1972+
assert "client_id=https%3A%2F%2Fexample.com%2Fclient" in content
1973+
1974+
# Verify client info was set correctly
1975+
assert provider.context.client_info is not None
1976+
assert provider.context.client_info.client_id == "https://example.com/client"
1977+
assert provider.context.client_info.token_endpoint_auth_method == "none"
1978+
1979+
# Complete the flow
1980+
token_response = httpx.Response(
1981+
200,
1982+
content=b'{"access_token": "test_token", "token_type": "Bearer", "expires_in": 3600}',
1983+
request=token_request,
1984+
)
1985+
1986+
final_request = await auth_flow.asend(token_response)
1987+
assert final_request.headers["Authorization"] == "Bearer test_token"
1988+
1989+
final_response = httpx.Response(200, request=final_request)
1990+
try:
1991+
await auth_flow.asend(final_response)
1992+
except StopAsyncIteration:
1993+
pass
1994+
1995+
@pytest.mark.anyio
1996+
async def test_auth_flow_falls_back_to_dcr_when_no_cimd_support(
1997+
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
1998+
):
1999+
"""Test that auth flow falls back to DCR when server doesn't support CIMD."""
2000+
2001+
async def redirect_handler(url: str) -> None:
2002+
pass # pragma: no cover
2003+
2004+
async def callback_handler() -> tuple[str, str | None]:
2005+
return "test_auth_code", "test_state" # pragma: no cover
2006+
2007+
provider = OAuthClientProvider(
2008+
server_url="https://api.example.com/v1/mcp",
2009+
client_metadata=client_metadata,
2010+
storage=mock_storage,
2011+
redirect_handler=redirect_handler,
2012+
callback_handler=callback_handler,
2013+
client_metadata_url="https://example.com/client",
2014+
)
2015+
2016+
provider.context.current_tokens = None
2017+
provider.context.token_expiry_time = None
2018+
provider._initialized = True
2019+
2020+
test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
2021+
auth_flow = provider.async_auth_flow(test_request)
2022+
2023+
# First request
2024+
request = await auth_flow.__anext__()
2025+
2026+
# Send 401 response
2027+
response = httpx.Response(401, headers={}, request=test_request)
2028+
2029+
# PRM discovery
2030+
prm_request = await auth_flow.asend(response)
2031+
prm_response = httpx.Response(
2032+
200,
2033+
content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}',
2034+
request=prm_request,
2035+
)
2036+
2037+
# OAuth metadata discovery - server does NOT support CIMD
2038+
oauth_request = await auth_flow.asend(prm_response)
2039+
oauth_response = httpx.Response(
2040+
200,
2041+
content=(
2042+
b'{"issuer": "https://auth.example.com", '
2043+
b'"authorization_endpoint": "https://auth.example.com/authorize", '
2044+
b'"token_endpoint": "https://auth.example.com/token", '
2045+
b'"registration_endpoint": "https://auth.example.com/register"}'
2046+
),
2047+
request=oauth_request,
2048+
)
2049+
2050+
# Should proceed to DCR instead of skipping it
2051+
registration_request = await auth_flow.asend(oauth_response)
2052+
assert registration_request.method == "POST"
2053+
assert str(registration_request.url) == "https://auth.example.com/register"
2054+
2055+
# Complete the flow to avoid generator cleanup issues
2056+
registration_response = httpx.Response(
2057+
201,
2058+
content=b'{"client_id": "dcr_client_id", "redirect_uris": ["http://localhost:3030/callback"]}',
2059+
request=registration_request,
2060+
)
2061+
2062+
# Mock authorization
2063+
provider._perform_authorization_code_grant = mock.AsyncMock(
2064+
return_value=("test_auth_code", "test_code_verifier")
2065+
)
2066+
2067+
token_request = await auth_flow.asend(registration_response)
2068+
token_response = httpx.Response(
2069+
200,
2070+
content=b'{"access_token": "test_token", "token_type": "Bearer", "expires_in": 3600}',
2071+
request=token_request,
2072+
)
2073+
2074+
final_request = await auth_flow.asend(token_response)
2075+
final_response = httpx.Response(200, request=final_request)
2076+
try:
2077+
await auth_flow.asend(final_response)
2078+
except StopAsyncIteration:
2079+
pass

0 commit comments

Comments
 (0)