diff --git a/tests/test_auth.py b/tests/test_auth.py index 6565475..52e07ab 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,12 +1,17 @@ +import math + import pytest from app.models import User from tests import utils +from flask_jwt_extended import decode_token class TestAuth: TEST_EMAIL = "testuser@example.com" TEST_PASSWORD = "testpassword" + ACCESS_TOKEN_DELTA = 10800 # 3 hours in seconds + REFRESH_TOKEN_DELTA = 259200 # 3 days in seconds @pytest.fixture(autouse=True) def setup(self, client): @@ -41,6 +46,34 @@ def _test_invalid_request_data(self, endpoint, expected_status=400): response = self.client.post(endpoint, data="not json data") assert response.status_code == 415 + def _decode_token(self, token): + # Needs Flask app context for secret/algorithms from current_app.config + with self.client.application.app_context(): + return decode_token(token, allow_expired=False) + def _assert_jwt_structure(self, token, expected_sub, expected_type, fresh=False): + assert token.count(".") == 2, f"Token does not have three segments: {token}" + payload = self._decode_token(token) + assert payload["sub"] == expected_sub + assert payload["type"] == expected_type + assert "iat" in payload + assert "exp" in payload + assert "jti" in payload + assert payload["fresh"] is fresh + + # Expiry check + expected_delta = None + if expected_type == "access": + expected_delta = self.ACCESS_TOKEN_DELTA + elif expected_type == "refresh": + expected_delta = self.REFRESH_TOKEN_DELTA + + if expected_delta is not None: + actual_delta = payload["exp"] - payload["iat"] + # Allow a small margin (e.g., 0-2 seconds) for processing time + assert math.isclose(actual_delta, expected_delta, abs_tol=2), ( + f"Token expiry delta {actual_delta} != expected {expected_delta}" + ) + def test_register_success(self, register_user): response = register_user(self.TEST_EMAIL, self.TEST_PASSWORD) @@ -81,8 +114,10 @@ def test_login_success(self, register_user, login_user): data = response.get_json() assert "access_token" in data assert "refresh_token" in data - assert len(data["access_token"]) > 0 - assert len(data["refresh_token"]) > 0 + + user = self._verify_user_in_db(self.TEST_EMAIL) + self._assert_jwt_structure(data["access_token"], expected_sub=str(user.id), expected_type="access", fresh=True) + self._assert_jwt_structure(data["refresh_token"], expected_sub=str(user.id), expected_type="refresh") def test_login_invalid_password(self, register_user, login_user): register_user(self.TEST_EMAIL, self.TEST_PASSWORD) @@ -109,6 +144,9 @@ def test_refresh_token(self, register_user, login_user): assert data["access_token"] != original_access_token assert "refresh_token" not in data + user = self._verify_user_in_db(self.TEST_EMAIL) + self._assert_jwt_structure(data["access_token"], expected_sub=str(user.id), expected_type="access") + def test_refresh_token_invalid(self, register_user, login_user): # Access token test register_user(self.TEST_EMAIL, self.TEST_PASSWORD)