@@ -232,3 +232,282 @@ def test_no_root_path_in_openapi_spec(source_api: FastAPI, source_api_server: st
232232 assert response .status_code == 200
233233 openapi = response .json ()
234234 assert "servers" not in openapi
235+
236+
237+ def test_upstream_servers_removed_when_root_path_set (
238+ source_api : FastAPI , source_api_server : str , source_api_responses
239+ ):
240+ """When upstream API has servers field and proxy has root_path, upstream servers are removed and replaced with proxy servers."""
241+ # Configure upstream API to return a servers field
242+ upstream_servers = [{"url" : "https://upstream-api.com/stage" }]
243+ # Add the /api endpoint to the responses
244+ source_api_responses ["/api" ] = {
245+ "GET" : {
246+ "openapi" : "3.0.0" ,
247+ "info" : {"title" : "Test API" , "version" : "1.0.0" },
248+ "paths" : {},
249+ "servers" : upstream_servers ,
250+ }
251+ }
252+
253+ root_path = "/api/v1"
254+ app = app_factory (
255+ upstream_url = source_api_server ,
256+ openapi_spec_endpoint = source_api .openapi_url ,
257+ root_path = root_path ,
258+ )
259+ client = TestClient (app )
260+ response = client .get (root_path + source_api .openapi_url )
261+ assert response .status_code == 200
262+ openapi = response .json ()
263+
264+ # Verify upstream servers are removed and replaced with proxy servers
265+ assert "servers" in openapi
266+ assert openapi ["servers" ] == [{"url" : root_path }]
267+ assert openapi ["servers" ] != upstream_servers
268+
269+
270+ def test_upstream_servers_removed_when_no_root_path (
271+ source_api : FastAPI , source_api_server : str , source_api_responses
272+ ):
273+ """When upstream API has servers field and proxy has no root_path, upstream servers are removed and no servers field is added."""
274+ # Configure upstream API to return a servers field
275+ upstream_servers = [{"url" : "https://upstream-api.com/stage" }]
276+ # Add the /api endpoint to the responses
277+ source_api_responses ["/api" ] = {
278+ "GET" : {
279+ "openapi" : "3.0.0" ,
280+ "info" : {"title" : "Test API" , "version" : "1.0.0" },
281+ "paths" : {},
282+ "servers" : upstream_servers ,
283+ }
284+ }
285+
286+ app = app_factory (
287+ upstream_url = source_api_server ,
288+ openapi_spec_endpoint = source_api .openapi_url ,
289+ root_path = "" , # No root path
290+ )
291+ client = TestClient (app )
292+ response = client .get (source_api .openapi_url )
293+ assert response .status_code == 200
294+ openapi = response .json ()
295+
296+ # Verify upstream servers are removed and no servers field is added
297+ assert "servers" not in openapi
298+
299+
300+ def test_no_servers_field_when_upstream_has_none (
301+ source_api : FastAPI , source_api_server : str , source_api_responses
302+ ):
303+ """When upstream API has no servers field and proxy has no root_path, no servers field is added."""
304+ # Configure upstream API to return no servers field
305+ source_api_responses ["/api" ] = {
306+ "GET" : {
307+ "openapi" : "3.0.0" ,
308+ "info" : {"title" : "Test API" , "version" : "1.0.0" },
309+ "paths" : {},
310+ # No servers field
311+ }
312+ }
313+
314+ app = app_factory (
315+ upstream_url = source_api_server ,
316+ openapi_spec_endpoint = source_api .openapi_url ,
317+ root_path = "" , # No root path
318+ )
319+ client = TestClient (app )
320+ response = client .get (source_api .openapi_url )
321+ assert response .status_code == 200
322+ openapi = response .json ()
323+
324+ # Verify no servers field is added
325+ assert "servers" not in openapi
326+
327+
328+ def test_multiple_upstream_servers_removed (
329+ source_api : FastAPI , source_api_server : str , source_api_responses
330+ ):
331+ """When upstream API has multiple servers, all are removed and replaced with proxy server."""
332+ # Configure upstream API to return multiple servers
333+ upstream_servers = [
334+ {"url" : "https://upstream-api.com/stage" },
335+ {"url" : "https://upstream-api.com/prod" },
336+ {
337+ "url" : "https://staging.upstream-api.com" ,
338+ "description" : "Staging environment" ,
339+ },
340+ ]
341+ source_api_responses ["/api" ] = {
342+ "GET" : {
343+ "openapi" : "3.0.0" ,
344+ "info" : {"title" : "Test API" , "version" : "1.0.0" },
345+ "paths" : {},
346+ "servers" : upstream_servers ,
347+ }
348+ }
349+
350+ root_path = "/api/v1"
351+ app = app_factory (
352+ upstream_url = source_api_server ,
353+ openapi_spec_endpoint = source_api .openapi_url ,
354+ root_path = root_path ,
355+ )
356+ client = TestClient (app )
357+ response = client .get (root_path + source_api .openapi_url )
358+ assert response .status_code == 200
359+ openapi = response .json ()
360+
361+ # Verify all upstream servers are removed and replaced with proxy server
362+ assert "servers" in openapi
363+ assert openapi ["servers" ] == [{"url" : root_path }]
364+ assert len (openapi ["servers" ]) == 1
365+ assert openapi ["servers" ] != upstream_servers
366+
367+
368+ def test_upstream_servers_with_variables_removed (
369+ source_api : FastAPI , source_api_server : str , source_api_responses
370+ ):
371+ """When upstream API has servers with variables, they are removed and replaced with proxy server."""
372+ # Configure upstream API to return servers with variables
373+ upstream_servers = [
374+ {
375+ "url" : "https://{environment}.upstream-api.com/{version}" ,
376+ "variables" : {
377+ "environment" : {"default" : "prod" , "enum" : ["dev" , "staging" , "prod" ]},
378+ "version" : {"default" : "v1" , "enum" : ["v1" , "v2" ]},
379+ },
380+ }
381+ ]
382+ source_api_responses ["/api" ] = {
383+ "GET" : {
384+ "openapi" : "3.0.0" ,
385+ "info" : {"title" : "Test API" , "version" : "1.0.0" },
386+ "paths" : {},
387+ "servers" : upstream_servers ,
388+ }
389+ }
390+
391+ root_path = "/api/v1"
392+ app = app_factory (
393+ upstream_url = source_api_server ,
394+ openapi_spec_endpoint = source_api .openapi_url ,
395+ root_path = root_path ,
396+ )
397+ client = TestClient (app )
398+ response = client .get (root_path + source_api .openapi_url )
399+ assert response .status_code == 200
400+ openapi = response .json ()
401+
402+ # Verify upstream servers with variables are removed and replaced with proxy server
403+ assert "servers" in openapi
404+ assert openapi ["servers" ] == [{"url" : root_path }]
405+ assert len (openapi ["servers" ]) == 1
406+ assert openapi ["servers" ] != upstream_servers
407+
408+
409+ def test_malformed_servers_field_handled (
410+ source_api : FastAPI , source_api_server : str , source_api_responses
411+ ):
412+ """When upstream API has malformed servers field, it is removed and replaced with proxy server."""
413+ # Configure upstream API to return malformed servers field
414+ source_api_responses ["/api" ] = {
415+ "GET" : {
416+ "openapi" : "3.0.0" ,
417+ "info" : {"title" : "Test API" , "version" : "1.0.0" },
418+ "paths" : {},
419+ "servers" : "invalid_servers_field" , # Should be a list
420+ }
421+ }
422+
423+ root_path = "/api/v1"
424+ app = app_factory (
425+ upstream_url = source_api_server ,
426+ openapi_spec_endpoint = source_api .openapi_url ,
427+ root_path = root_path ,
428+ )
429+ client = TestClient (app )
430+ response = client .get (root_path + source_api .openapi_url )
431+ assert response .status_code == 200
432+ openapi = response .json ()
433+
434+ # Verify malformed servers field is removed and replaced with proxy server
435+ assert "servers" in openapi
436+ assert openapi ["servers" ] == [{"url" : root_path }]
437+ assert isinstance (openapi ["servers" ], list )
438+
439+
440+ def test_empty_servers_list_removed (
441+ source_api : FastAPI , source_api_server : str , source_api_responses
442+ ):
443+ """When upstream API has empty servers list, it is removed and replaced with proxy server."""
444+ # Configure upstream API to return empty servers list
445+ source_api_responses ["/api" ] = {
446+ "GET" : {
447+ "openapi" : "3.0.0" ,
448+ "info" : {"title" : "Test API" , "version" : "1.0.0" },
449+ "paths" : {},
450+ "servers" : [], # Empty list
451+ }
452+ }
453+
454+ root_path = "/api/v1"
455+ app = app_factory (
456+ upstream_url = source_api_server ,
457+ openapi_spec_endpoint = source_api .openapi_url ,
458+ root_path = root_path ,
459+ )
460+ client = TestClient (app )
461+ response = client .get (root_path + source_api .openapi_url )
462+ assert response .status_code == 200
463+ openapi = response .json ()
464+
465+ # Verify empty servers list is removed and replaced with proxy server
466+ assert "servers" in openapi
467+ assert openapi ["servers" ] == [{"url" : root_path }]
468+ assert len (openapi ["servers" ]) == 1
469+
470+
471+ @pytest .mark .parametrize ("root_path" , [None , "/api/v1" ])
472+ def test_servers_are_replaced_with_proxy_server (root_path : str ):
473+ """Test that verifies upstream servers are replaced with proxy server."""
474+ from unittest .mock import Mock
475+
476+ from stac_auth_proxy .middleware .UpdateOpenApiMiddleware import OpenApiMiddleware
477+
478+ # Test data with upstream servers
479+ test_data = {
480+ "openapi" : "3.0.0" ,
481+ "info" : {"title" : "Test API" , "version" : "1.0.0" },
482+ "paths" : {},
483+ "servers" : [
484+ {"url" : "https://upstream-api.com/stage" },
485+ {"url" : "https://upstream-api.com/prod" },
486+ ],
487+ }
488+
489+ # Create middleware instance
490+ middleware = OpenApiMiddleware (
491+ app = Mock (),
492+ openapi_spec_path = "/api" ,
493+ oidc_discovery_url = "https://example.com/.well-known/openid-configuration" ,
494+ private_endpoints = {},
495+ public_endpoints = {},
496+ default_public = True ,
497+ root_path = root_path ,
498+ )
499+
500+ # Test the middleware behavior
501+ result = middleware .transform_json (test_data .copy (), Mock ())
502+
503+ # Verify that only the proxy server remains
504+ if root_path :
505+ assert "servers" in result
506+ assert len (result ["servers" ]) == 1
507+ assert result ["servers" ][0 ]["url" ] == root_path
508+ else :
509+ assert "servers" not in result
510+
511+ # Verify upstream servers are gone
512+ for server in test_data ["servers" ]:
513+ assert server not in result .get ("servers" , [])
0 commit comments