|
1 | 1 | """tests.test_api_gateway.test_rest.service module.""" |
2 | 2 |
|
| 3 | +import os |
3 | 4 | import unittest |
| 5 | +from unittest import ( |
| 6 | + mock, |
| 7 | +) |
4 | 8 |
|
| 9 | +import attr |
5 | 10 | from aiohttp.test_utils import ( |
6 | 11 | AioHTTPTestCase, |
7 | 12 | unittest_run_loop, |
8 | 13 | ) |
| 14 | +from aiohttp_middlewares.cors import ( |
| 15 | + ACCESS_CONTROL_ALLOW_HEADERS, |
| 16 | + ACCESS_CONTROL_ALLOW_METHODS, |
| 17 | + ACCESS_CONTROL_ALLOW_ORIGIN, |
| 18 | + DEFAULT_ALLOW_HEADERS, |
| 19 | + DEFAULT_ALLOW_METHODS, |
| 20 | +) |
9 | 21 |
|
10 | 22 | from minos.api_gateway.common import ( |
11 | 23 | MinosConfig, |
@@ -200,5 +212,78 @@ async def test_get(self): |
200 | 212 | self.assertIn("The requested endpoint is not available.", await response.text()) |
201 | 213 |
|
202 | 214 |
|
| 215 | +class TestApiGatewayCORS(AioHTTPTestCase): |
| 216 | + CONFIG_FILE_PATH = BASE_PATH / "config.yml" |
| 217 | + TEST_DENIED_ORIGIN = "https://www.google.com" |
| 218 | + TEST_ORIGIN = "http://localhost:3000" |
| 219 | + |
| 220 | + @mock.patch.dict(os.environ, {"API_GATEWAY_CORS_ENABLED": "true"}) |
| 221 | + def setUp(self) -> None: |
| 222 | + self.config = MinosConfig(self.CONFIG_FILE_PATH) |
| 223 | + |
| 224 | + self.discovery = MockServer( |
| 225 | + host=self.config.discovery.connection.host, port=self.config.discovery.connection.port, |
| 226 | + ) |
| 227 | + self.discovery.add_json_response( |
| 228 | + "/microservices", {"address": "localhost", "port": "5568", "status": True}, |
| 229 | + ) |
| 230 | + |
| 231 | + self.microservice = MockServer(host="localhost", port=5568) |
| 232 | + self.microservice.add_json_response( |
| 233 | + "/order/5", "Microservice call correct!!!", methods=("GET", "PUT", "PATCH", "DELETE",) |
| 234 | + ) |
| 235 | + self.microservice.add_json_response("/order", "Microservice call correct!!!", methods=("POST",)) |
| 236 | + |
| 237 | + self.discovery.start() |
| 238 | + self.microservice.start() |
| 239 | + super().setUp() |
| 240 | + |
| 241 | + def tearDown(self) -> None: |
| 242 | + self.discovery.shutdown_server() |
| 243 | + self.microservice.shutdown_server() |
| 244 | + super().tearDown() |
| 245 | + |
| 246 | + async def get_application(self): |
| 247 | + """ |
| 248 | + Override the get_app method to return your application. |
| 249 | + """ |
| 250 | + rest_service = ApiGatewayRestService( |
| 251 | + address=self.config.rest.connection.host, port=self.config.rest.connection.port, config=self.config |
| 252 | + ) |
| 253 | + |
| 254 | + return await rest_service.create_application() |
| 255 | + |
| 256 | + @staticmethod |
| 257 | + def check_allow_origin( |
| 258 | + response, origin, *, allow_headers=DEFAULT_ALLOW_HEADERS, allow_methods=DEFAULT_ALLOW_METHODS, |
| 259 | + ): |
| 260 | + assert response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == origin |
| 261 | + if allow_headers: |
| 262 | + assert response.headers[ACCESS_CONTROL_ALLOW_HEADERS] == ", ".join(allow_headers) |
| 263 | + if allow_methods: |
| 264 | + assert response.headers[ACCESS_CONTROL_ALLOW_METHODS] == ", ".join(allow_methods) |
| 265 | + |
| 266 | + @unittest_run_loop |
| 267 | + async def test_cors_enabled(self): |
| 268 | + method = "GET" |
| 269 | + extra_headers = {} |
| 270 | + expected_origin = "*" |
| 271 | + expected_allow_headers = None |
| 272 | + expected_allow_methods = None |
| 273 | + url = "/order/5?verb=GET&path=12324" |
| 274 | + |
| 275 | + kwargs = {} |
| 276 | + if expected_allow_headers is not attr.NOTHING: |
| 277 | + kwargs["allow_headers"] = expected_allow_headers |
| 278 | + if expected_allow_methods is not attr.NOTHING: |
| 279 | + kwargs["allow_methods"] = expected_allow_methods |
| 280 | + |
| 281 | + self.check_allow_origin( |
| 282 | + await self.client.request(method, url, headers={"Origin": self.TEST_ORIGIN, **extra_headers}), |
| 283 | + expected_origin, |
| 284 | + **kwargs, |
| 285 | + ) |
| 286 | + |
| 287 | + |
203 | 288 | if __name__ == "__main__": |
204 | 289 | unittest.main() |
0 commit comments