|
3 | 3 | from fastapi import FastAPI |
4 | 4 | from fastapi.testclient import TestClient |
5 | 5 | from utils import AppFactory |
| 6 | +import gzip |
| 7 | +from fastapi import Request, Response |
6 | 8 |
|
7 | 9 | app_factory = AppFactory( |
8 | 10 | oidc_discovery_url="https://example-stac-api.com/.well-known/openid-configuration" |
@@ -59,37 +61,113 @@ def test_oidc_in_openapi_spec(source_api: FastAPI, source_api_server: str): |
59 | 61 | assert "oidcAuth" in openapi.get("components", {}).get("securitySchemes", {}) |
60 | 62 |
|
61 | 63 |
|
| 64 | +# def test_oidc_in_openapi_spec_compressed(source_api: FastAPI, source_api_server: str): |
| 65 | +# """When OpenAPI spec endpoint is set, the proxied OpenAPI spec is augmented with oidc details.""" |
| 66 | + |
| 67 | +# # Create a compressed response factory |
| 68 | +# def compressed_response_factory(request: Request): |
| 69 | +# assert False |
| 70 | +# # Get the original OpenAPI spec |
| 71 | +# openapi = source_api.openapi() |
| 72 | +# compressed_data = gzip.compress(str(openapi).encode()) |
| 73 | +# return Response( |
| 74 | +# content=compressed_data, |
| 75 | +# headers={ |
| 76 | +# "Content-Encoding": "gzip", |
| 77 | +# "Content-Type": "application/json", |
| 78 | +# }, |
| 79 | +# ) |
| 80 | + |
| 81 | +# # Set the custom response factory |
| 82 | +# source_api.state.response_factory = compressed_response_factory |
| 83 | + |
| 84 | +# app = app_factory( |
| 85 | +# upstream_url=source_api_server, |
| 86 | +# openapi_spec_endpoint=source_api.openapi_url, |
| 87 | +# ) |
| 88 | +# client = TestClient(app) |
| 89 | +# response = client.get( |
| 90 | +# source_api.openapi_url, |
| 91 | +# headers={"Accept-Encoding": "gzip"}, |
| 92 | +# ) |
| 93 | +# assert response.status_code == 200 |
| 94 | +# assert response.headers.get("content-encoding") == "gzip" |
| 95 | + |
| 96 | +# openapi = response.json() # TestClient automatically decompresses |
| 97 | +# assert "info" in openapi |
| 98 | +# assert "openapi" in openapi |
| 99 | +# assert "paths" in openapi |
| 100 | +# assert "oidcAuth" in openapi.get("components", {}).get("securitySchemes", {}) |
| 101 | + |
| 102 | + |
62 | 103 | def test_oidc_in_openapi_spec_private_endpoints( |
63 | 104 | source_api: FastAPI, source_api_server: str |
64 | 105 | ): |
65 | 106 | """When OpenAPI spec endpoint is set & endpoints are marked private, those endpoints are marked private in the spec.""" |
66 | 107 | private_endpoints = { |
67 | 108 | # https://github.com/stac-api-extensions/collection-transaction/blob/v1.0.0-beta.1/README.md#methods |
| 109 | + r"^/collections$": ["POST"], |
| 110 | + r"^/collections/([^/]+)$": ["PUT", "PATCH", "DELETE"], |
| 111 | + # https://github.com/stac-api-extensions/transaction/blob/v1.0.0-rc.3/README.md#methods |
| 112 | + r"^/collections/([^/]+)/items$": ["POST"], |
| 113 | + r"^/collections/([^/]+)/items/([^/]+)$": ["PUT", "PATCH", "DELETE"], |
| 114 | + # https://stac-utils.github.io/stac-fastapi/api/stac_fastapi/extensions/third_party/bulk_transactions/#bulktransactionextension |
| 115 | + r"^/collections/([^/]+)/bulk_items$": ["POST"], |
| 116 | + } |
| 117 | + app = app_factory( |
| 118 | + upstream_url=source_api_server, |
| 119 | + openapi_spec_endpoint=source_api.openapi_url, |
| 120 | + default_public=True, |
| 121 | + private_endpoints=private_endpoints, |
| 122 | + ) |
| 123 | + client = TestClient(app) |
| 124 | + |
| 125 | + openapi = client.get(source_api.openapi_url).raise_for_status().json() |
| 126 | + |
| 127 | + expected_auth = { |
68 | 128 | "/collections": ["POST"], |
69 | 129 | "/collections/{collection_id}": ["PUT", "PATCH", "DELETE"], |
70 | | - # https://github.com/stac-api-extensions/transaction/blob/v1.0.0-rc.3/README.md#methods |
71 | 130 | "/collections/{collection_id}/items": ["POST"], |
72 | 131 | "/collections/{collection_id}/items/{item_id}": ["PUT", "PATCH", "DELETE"], |
73 | | - # https://stac-utils.github.io/stac-fastapi/api/stac_fastapi/extensions/third_party/bulk_transactions/#bulktransactionextension |
74 | 132 | "/collections/{collection_id}/bulk_items": ["POST"], |
75 | 133 | } |
| 134 | + for path, method_config in openapi["paths"].items(): |
| 135 | + for method, config in method_config.items(): |
| 136 | + security = config.get("security") |
| 137 | + path_in_expected_auth = path in expected_auth |
| 138 | + method_in_expected_auth = any( |
| 139 | + method.casefold() == m.casefold() for m in expected_auth.get(path, []) |
| 140 | + ) |
| 141 | + if security: |
| 142 | + assert path_in_expected_auth |
| 143 | + assert method_in_expected_auth |
| 144 | + else: |
| 145 | + assert not all([path_in_expected_auth, method_in_expected_auth]) |
| 146 | + |
| 147 | + |
| 148 | +def test_oidc_in_openapi_spec_public_endpoints( |
| 149 | + source_api: FastAPI, source_api_server: str |
| 150 | +): |
| 151 | + """When OpenAPI spec endpoint is set & endpoints are marked public, those endpoints are not marked private in the spec.""" |
| 152 | + public = {r"^/queryables$": ["GET"], r"^/api": ["GET"]} |
76 | 153 | app = app_factory( |
77 | 154 | upstream_url=source_api_server, |
78 | 155 | openapi_spec_endpoint=source_api.openapi_url, |
79 | | - private_endpoints=private_endpoints, |
| 156 | + default_public=False, |
| 157 | + public_endpoints=public, |
80 | 158 | ) |
81 | 159 | client = TestClient(app) |
| 160 | + |
82 | 161 | openapi = client.get(source_api.openapi_url).raise_for_status().json() |
83 | | - for path, methods in private_endpoints.items(): |
84 | | - for method in methods: |
85 | | - openapi_path = openapi["paths"].get(path) |
86 | | - assert openapi_path, f"Path {path} not found in OpenAPI spec" |
87 | | - openapi_path_method = openapi_path.get(method.lower()) |
88 | | - assert ( |
89 | | - openapi_path_method |
90 | | - ), f"Method {method.lower()!r} not found for path {path!r} in OpenAPI spec for path {path}" |
91 | | - security = openapi_path_method.get("security") |
92 | | - assert security, f"Security not found for {path!r} {method.lower()!r}" |
93 | | - assert any( |
94 | | - "oidcAuth" in s for s in security |
95 | | - ), f'No "oidcAuth" in security for {path!r} {method.lower()!r}' |
| 162 | + |
| 163 | + expected_auth = {"/queryables": ["GET"]} |
| 164 | + for path, method_config in openapi["paths"].items(): |
| 165 | + for method, config in method_config.items(): |
| 166 | + security = config.get("security") |
| 167 | + if security: |
| 168 | + assert path not in expected_auth |
| 169 | + else: |
| 170 | + assert path in expected_auth |
| 171 | + assert any( |
| 172 | + method.casefold() == m.casefold() for m in expected_auth[path] |
| 173 | + ) |
0 commit comments