22Tests for refactored OAuth client authentication implementation.
33"""
44
5- import time
65import asyncio
6+ import time
7+ from unittest .mock import AsyncMock , Mock , patch
78
89import httpx
910import pytest
1011from pydantic import AnyHttpUrl , AnyUrl
11- from unittest .mock import AsyncMock , Mock , patch
1212
13- from mcp .client .auth import OAuthClientProvider , PKCEParameters
13+ from mcp .client .auth import (
14+ ClientCredentialsProvider ,
15+ OAuthClientProvider ,
16+ PKCEParameters ,
17+ TokenExchangeProvider ,
18+ )
1419from mcp .shared .auth import (
1520 OAuthClientInformationFull ,
1621 OAuthClientMetadata ,
22+ OAuthMetadata ,
1723 OAuthToken ,
1824)
1925
@@ -81,6 +87,8 @@ async def callback_handler() -> tuple[str, str | None]:
8187 redirect_handler = redirect_handler ,
8288 callback_handler = callback_handler ,
8389 )
90+
91+
8492@pytest .fixture
8593def client_credentials_metadata ():
8694 return OAuthClientMetadata (
@@ -92,6 +100,45 @@ def client_credentials_metadata():
92100 token_endpoint_auth_method = "client_secret_post" ,
93101 )
94102
103+
104+ @pytest .fixture
105+ def oauth_metadata ():
106+ return OAuthMetadata (
107+ issuer = AnyHttpUrl ("https://auth.example.com" ),
108+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
109+ token_endpoint = AnyHttpUrl ("https://auth.example.com/token" ),
110+ registration_endpoint = AnyHttpUrl ("https://auth.example.com/register" ),
111+ scopes_supported = ["read" , "write" , "admin" ],
112+ response_types_supported = ["code" ],
113+ grant_types_supported = ["authorization_code" , "refresh_token" , "client_credentials" ],
114+ code_challenge_methods_supported = ["S256" ],
115+ )
116+
117+
118+ @pytest .fixture
119+ def oauth_client_info ():
120+ return OAuthClientInformationFull (
121+ client_id = "test_client_id" ,
122+ client_secret = "test_client_secret" ,
123+ redirect_uris = [AnyUrl ("http://localhost:3000/callback" )],
124+ client_name = "Test Client" ,
125+ grant_types = ["authorization_code" , "refresh_token" ],
126+ response_types = ["code" ],
127+ scope = "read write" ,
128+ )
129+
130+
131+ @pytest .fixture
132+ def oauth_token ():
133+ return OAuthToken (
134+ access_token = "test_access_token" ,
135+ token_type = "bearer" ,
136+ expires_in = 3600 ,
137+ refresh_token = "test_refresh_token" ,
138+ scope = "read write" ,
139+ )
140+
141+
95142@pytest .fixture
96143async def client_credentials_provider (client_credentials_metadata , mock_storage ):
97144 return ClientCredentialsProvider (
@@ -100,6 +147,7 @@ async def client_credentials_provider(client_credentials_metadata, mock_storage)
100147 storage = mock_storage ,
101148 )
102149
150+
103151@pytest .fixture
104152async def token_exchange_provider (client_credentials_metadata , mock_storage ):
105153 return TokenExchangeProvider (
@@ -342,6 +390,8 @@ async def test_auth_flow_with_valid_tokens(self, oauth_provider, mock_storage, v
342390 await auth_flow .asend (response )
343391 except StopAsyncIteration :
344392 pass # Expected
393+
394+
345395class TestClientCredentialsProvider :
346396 @pytest .mark .anyio
347397 async def test_request_token_success (
@@ -417,4 +467,3 @@ async def test_request_token_success(
417467
418468 mock_client .post .assert_called_once ()
419469 assert token_exchange_provider ._current_tokens .access_token == oauth_token .access_token
420-
0 commit comments