22from contextlib import asynccontextmanager
33from dataclasses import dataclass , field
44from types import SimpleNamespace
5+ from urllib .parse import parse_qs , urlparse
6+ from uuid import uuid4
57
68import pytest
79from fastapi import FastAPI
810from fastapi .testclient import TestClient
11+ from pydantic import AnyUrl
912
1013pytest .importorskip ("mcp" )
1114
1215from belgie_core .core .settings import BelgieSettings
1316from belgie_mcp .plugin import Mcp , McpPlugin
1417from belgie_mcp .verifier import BelgieOAuthTokenVerifier
18+ from belgie_oauth_server .models import OAuthClientMetadata
19+ from belgie_oauth_server .plugin import OAuthServerPlugin
20+ from belgie_oauth_server .provider import AccessToken as OAuthAccessToken , AuthorizationParams , SimpleOAuthProvider
1521from belgie_oauth_server .settings import OAuthServer
22+ from mcp .server .auth .provider import AccessToken
1623from mcp .server .mcpserver import MCPServer
1724
1825
@@ -103,6 +110,43 @@ def test_mcp_plugin_defaults_base_url_from_belgie_settings() -> None:
103110 assert str (plugin .auth .resource_server_url ) == "https://example.com/mcp"
104111
105112
113+ @pytest .mark .asyncio
114+ async def test_mcp_plugin_verifier_uses_linked_oauth_plugin_provider () -> None :
115+ settings = OAuthServer (
116+ base_url = "https://auth.local" ,
117+ redirect_uris = ["http://localhost/callback" ],
118+ client_id = "client" ,
119+ client_secret = "secret" ,
120+ default_scope = "user" ,
121+ )
122+ provider = SimpleOAuthProvider (settings , issuer_url = str (settings .issuer_url ))
123+ oauth_plugin = OAuthServerPlugin (_belgie_settings (), settings )
124+ oauth_plugin ._provider = provider
125+ plugin = McpPlugin (
126+ _belgie_settings (),
127+ Mcp (
128+ oauth = settings ,
129+ server_url = "https://mcp.local/mcp" ,
130+ ),
131+ )
132+ _ = plugin .router (SimpleNamespace (plugins = [oauth_plugin , plugin ]))
133+ token_value , stored_token = await _issue_dynamic_client_access_token (
134+ provider ,
135+ user_id = str (uuid4 ()),
136+ resource = "https://mcp.local/mcp" ,
137+ )
138+
139+ token = await plugin .token_verifier .verify_token (token_value )
140+
141+ assert token == AccessToken (
142+ token = token_value ,
143+ client_id = stored_token .client_id ,
144+ scopes = ["user" ],
145+ expires_at = stored_token .expires_at ,
146+ resource = "https://mcp.local/mcp" ,
147+ )
148+
149+
106150def test_mount_streamable_http_accepts_alias_path_without_redirect () -> None :
107151 settings = OAuthServer (
108152 base_url = "https://auth.local" ,
@@ -121,7 +165,7 @@ def test_mount_streamable_http_accepts_alias_path_without_redirect() -> None:
121165 server = MCPServer (name = "Belgie MCP" )
122166 app = _build_test_app (plugin , server )
123167
124- with TestClient (app , base_url = "http ://localhost:8000 " ) as client :
168+ with TestClient (app , base_url = "https ://example.com " ) as client :
125169 alias_response = client .post (
126170 "/mcp" ,
127171 headers = {"Content-Type" : "application/json" },
@@ -142,6 +186,81 @@ def test_mount_streamable_http_accepts_alias_path_without_redirect() -> None:
142186 assert alias_response .json () == mounted_response .json ()
143187
144188
189+ @pytest .mark .parametrize ("host_header" , ["localhost:8000" , "127.0.0.1:8000" , "[::1]:8000" ])
190+ def test_mount_streamable_http_allows_loopback_hosts (host_header : str ) -> None :
191+ plugin = _build_plugin (server_url = "http://localhost:8000/mcp" )
192+ server = MCPServer (
193+ name = "Belgie MCP" ,
194+ auth = plugin .auth ,
195+ token_verifier = _AllowingTokenVerifier (),
196+ )
197+ app = _build_test_app (plugin , server )
198+
199+ with TestClient (app , base_url = "http://localhost:8000" ) as client :
200+ response = client .post (
201+ "/mcp" ,
202+ headers = {
203+ "Authorization" : "Bearer valid-token" ,
204+ "Host" : host_header ,
205+ },
206+ json = _build_initialize_request (),
207+ follow_redirects = False ,
208+ )
209+
210+ assert response .status_code == 200
211+ assert response .headers ["content-type" ].startswith ("text/event-stream" )
212+ assert response .headers .get ("mcp-session-id" )
213+
214+
215+ def test_mount_streamable_http_allows_configured_external_host () -> None :
216+ plugin = _build_plugin (server_url = "https://example.com/mcp" )
217+ server = MCPServer (
218+ name = "Belgie MCP" ,
219+ auth = plugin .auth ,
220+ token_verifier = _AllowingTokenVerifier (),
221+ )
222+ app = _build_test_app (plugin , server )
223+
224+ with TestClient (app , base_url = "https://example.com" ) as client :
225+ response = client .post (
226+ "/mcp" ,
227+ headers = {
228+ "Authorization" : "Bearer valid-token" ,
229+ "Host" : "example.com" ,
230+ },
231+ json = _build_initialize_request (),
232+ follow_redirects = False ,
233+ )
234+
235+ assert response .status_code == 200
236+ assert response .headers ["content-type" ].startswith ("text/event-stream" )
237+ assert response .headers .get ("mcp-session-id" )
238+
239+
240+ def test_mount_streamable_http_rejects_mismatched_host () -> None :
241+ plugin = _build_plugin (server_url = "http://localhost:8000/mcp" )
242+ server = MCPServer (
243+ name = "Belgie MCP" ,
244+ auth = plugin .auth ,
245+ token_verifier = _AllowingTokenVerifier (),
246+ )
247+ app = _build_test_app (plugin , server )
248+
249+ with TestClient (app , base_url = "http://localhost:8000" ) as client :
250+ response = client .post (
251+ "/mcp" ,
252+ headers = {
253+ "Authorization" : "Bearer valid-token" ,
254+ "Host" : "example.com" ,
255+ },
256+ json = _build_initialize_request (),
257+ follow_redirects = False ,
258+ )
259+
260+ assert response .status_code == 421
261+ assert response .text == "Invalid Host header"
262+
263+
145264def test_mount_streamable_http_preserves_auth_middleware () -> None :
146265 settings = OAuthServer (
147266 base_url = "https://auth.local" ,
@@ -165,7 +284,7 @@ def test_mount_streamable_http_preserves_auth_middleware() -> None:
165284 )
166285 app = _build_test_app (plugin , server )
167286
168- with TestClient (app , base_url = "http ://localhost:8000 " ) as client :
287+ with TestClient (app , base_url = "https ://example.com " ) as client :
169288 alias_response = client .post (
170289 "/mcp" ,
171290 headers = {"Authorization" : "Bearer alias-token" },
@@ -192,7 +311,7 @@ async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
192311 yield
193312
194313 app = FastAPI (lifespan = lifespan )
195- _ = plugin .mount_streamable_http (app , server , host = "localhost" )
314+ _ = plugin .mount_streamable_http (app , server )
196315 return app
197316
198317
@@ -202,3 +321,81 @@ class _StubTokenVerifier:
202321
203322 async def verify_token (self , token : str ) -> None :
204323 self .tokens .append (token )
324+
325+
326+ @dataclass (slots = True )
327+ class _AllowingTokenVerifier :
328+ async def verify_token (self , token : str ) -> AccessToken :
329+ return AccessToken (
330+ token = token ,
331+ client_id = "client" ,
332+ scopes = ["user" ],
333+ )
334+
335+
336+ def _build_plugin (* , server_url : str ) -> McpPlugin :
337+ settings = OAuthServer (
338+ base_url = "https://auth.local" ,
339+ redirect_uris = ["http://localhost/callback" ],
340+ client_id = "client" ,
341+ client_secret = "secret" ,
342+ default_scope = "user" ,
343+ )
344+ return McpPlugin (
345+ _belgie_settings (),
346+ Mcp (
347+ oauth = settings ,
348+ server_url = server_url ,
349+ ),
350+ )
351+
352+
353+ def _build_initialize_request () -> dict [str , object ]:
354+ return {
355+ "jsonrpc" : "2.0" ,
356+ "id" : 1 ,
357+ "method" : "initialize" ,
358+ "params" : {
359+ "protocolVersion" : "2025-03-26" ,
360+ "capabilities" : {},
361+ "clientInfo" : {"name" : "test-client" , "version" : "1" },
362+ },
363+ }
364+
365+
366+ async def _issue_dynamic_client_access_token (
367+ provider : SimpleOAuthProvider ,
368+ * ,
369+ user_id : str | None = None ,
370+ resource : str | None = None ,
371+ ) -> tuple [str , OAuthAccessToken ]:
372+ client = await provider .register_client (
373+ OAuthClientMetadata (
374+ redirect_uris = [AnyUrl ("http://localhost:6274/oauth/callback" )],
375+ grant_types = ["authorization_code" , "refresh_token" ],
376+ response_types = ["code" ],
377+ scope = "user" ,
378+ token_endpoint_auth_method = "none" ,
379+ ),
380+ )
381+ state = await provider .authorize (
382+ client ,
383+ AuthorizationParams (
384+ state = None ,
385+ scopes = ["user" ],
386+ code_challenge = "test-challenge" ,
387+ redirect_uri = AnyUrl ("http://localhost:6274/oauth/callback" ),
388+ redirect_uri_provided_explicitly = True ,
389+ resource = resource ,
390+ user_id = user_id ,
391+ session_id = str (uuid4 ()),
392+ ),
393+ )
394+ redirect = await provider .issue_authorization_code (state )
395+ code = parse_qs (urlparse (redirect ).query )["code" ][0 ]
396+ authorization_code = await provider .load_authorization_code (code )
397+ assert authorization_code is not None
398+ token_response = await provider .exchange_authorization_code (authorization_code )
399+ stored_token = await provider .load_access_token (token_response .access_token )
400+ assert stored_token is not None
401+ return token_response .access_token , stored_token
0 commit comments