diff --git a/requirements.txt b/requirements.txt index 4a17116..402e021 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,8 @@ python-dotenv==0.20.0 # Runtime dependencies gunicorn==20.1.0 honcho==1.1.0 +flask-talisman==0.7.0 +flask-cors==5.0.1 # Code quality pylint==2.14.0 diff --git a/service/__init__.py b/service/__init__.py index a62a9b3..5041638 100644 --- a/service/__init__.py +++ b/service/__init__.py @@ -8,11 +8,16 @@ from flask import Flask from service import config from service.common import log_handlers +from flask_talisman import Talisman +from flask_cors import CORS # Create Flask application app = Flask(__name__) app.config.from_object(config) +talisman = Talisman(app) +CORS(app) + # Import the routes After the Flask app is created # pylint: disable=wrong-import-position, cyclic-import, wrong-import-order from service import routes, models # noqa: F401 E402 diff --git a/tests/test_routes.py b/tests/test_routes.py index 8098157..d3e9085 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -12,12 +12,14 @@ from service.common import status # HTTP Status Codes from service.models import db, Account, init_db from service.routes import app +from service import talisman DATABASE_URI = os.getenv( "DATABASE_URI", "postgresql://postgres:postgres@localhost:5432/postgres" ) BASE_URL = "/accounts" +HTTPS_ENVIRON = {'wsgi.url_scheme': 'https'} ###################################################################### @@ -34,6 +36,7 @@ def setUpClass(cls): app.config["SQLALCHEMY_DATABASE_URI"] = DATABASE_URI app.logger.setLevel(logging.CRITICAL) init_db(app) + talisman.force_https = False @classmethod def tearDownClass(cls): @@ -172,4 +175,24 @@ def test_delete_account(self): def test_method_not_allowed(self): """It should not allow an illegal method call""" resp = self.client.delete(BASE_URL) - self.assertEqual(resp.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) \ No newline at end of file + self.assertEqual(resp.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) + + def test_security_headers(self): + """It should return security headers""" + response = self.client.get('/', environ_overrides=HTTPS_ENVIRON) + self.assertEqual(response.status_code, status.HTTP_200_OK) + headers = { + 'X-Frame-Options': 'SAMEORIGIN', + 'X-Content-Type-Options': 'nosniff', + 'Content-Security-Policy': 'default-src \'self\'', + 'Referrer-Policy': 'strict-origin-when-cross-origin' + } + for key, value in headers.items(): + self.assertEqual(response.headers.get(key), value) + + def test_cors_security(self): + """It should return a CORS header""" + response = self.client.get('/', environ_overrides=HTTPS_ENVIRON) + self.assertEqual(response.status_code, status.HTTP_200_OK) + # Check for the CORS header + self.assertEqual(response.headers.get('Access-Control-Allow-Origin'), '*') \ No newline at end of file