diff --git a/main.py b/main.py index 9ec7448..ec8fe93 100644 --- a/main.py +++ b/main.py @@ -8,8 +8,16 @@ from fastapi.exceptions import RequestValidationError, HTTPException, StarletteHTTPException from sqlmodel import Session from routers import authentication, organization, role, user -from utils.auth import get_user_with_relations, get_optional_user, NeedsNewTokens, get_user_from_reset_token, PasswordValidationError, AuthenticationError -from utils.models import User, Organization +from utils.auth import ( + HTML_PASSWORD_PATTERN, + get_user_with_relations, + get_optional_user, + NeedsNewTokens, + get_user_from_reset_token, + PasswordValidationError, + AuthenticationError +) +from utils.models import User from utils.db import get_session, set_up_db from utils.images import MAX_FILE_SIZE, MIN_DIMENSION, MAX_DIMENSION, ALLOWED_CONTENT_TYPES @@ -174,6 +182,8 @@ async def read_register( ): if params["user"]: return RedirectResponse(url="/dashboard", status_code=302) + + params["password_pattern"] = HTML_PASSWORD_PATTERN return templates.TemplateResponse(params["request"], "authentication/register.html", params) @@ -219,6 +229,7 @@ async def read_reset_password( params["email"] = email params["token"] = token + params["password_pattern"] = HTML_PASSWORD_PATTERN return templates.TemplateResponse(params["request"], "authentication/reset_password.html", params) diff --git a/templates/authentication/register.html b/templates/authentication/register.html index ea84753..6756ada 100644 --- a/templates/authentication/register.html +++ b/templates/authentication/register.html @@ -24,9 +24,8 @@
- diff --git a/templates/authentication/reset_password.html b/templates/authentication/reset_password.html index 9dd70ac..bf6884b 100644 --- a/templates/authentication/reset_password.html +++ b/templates/authentication/reset_password.html @@ -17,7 +17,7 @@
diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..d3334e9 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,148 @@ +import re +import string +import random +from datetime import timedelta +from urllib.parse import urlparse, parse_qs +from starlette.datastructures import URLPath +from main import app +from utils.auth import ( + create_access_token, + create_refresh_token, + verify_password, + get_password_hash, + validate_token, + generate_password_reset_url, + COMPILED_PASSWORD_PATTERN, + convert_python_regex_to_html +) + + +def test_convert_python_regex_to_html(): + PYTHON_SPECIAL_CHARS = r"(?=.*[\[\]\\@$!%*?&{}<>.,'#\-_=+\(\):;|~/\^])" + HTML_EQUIVALENT = r"(?=.*[\[\]\\@$!%*?&\{\}\<\>\.\,\\'#\-_=\+\(\):;\|~\/\^])" + + PYTHON_SPECIAL_CHARS = convert_python_regex_to_html(PYTHON_SPECIAL_CHARS) + + assert PYTHON_SPECIAL_CHARS == HTML_EQUIVALENT + + +def test_password_hashing(): + password = "Test123!@#" + hashed = get_password_hash(password) + assert verify_password(password, hashed) + assert not verify_password("wrong_password", hashed) + + +def test_token_creation_and_validation(): + data = {"sub": "test@example.com"} + + # Test access token + access_token = create_access_token(data) + decoded = validate_token(access_token, "access") + assert decoded is not None + assert decoded["sub"] == data["sub"] + assert decoded["type"] == "access" + + # Test refresh token + refresh_token = create_refresh_token(data) + decoded = validate_token(refresh_token, "refresh") + assert decoded is not None + assert decoded["sub"] == data["sub"] + assert decoded["type"] == "refresh" + + +def test_expired_token(): + data = {"sub": "test@example.com"} + expired_delta = timedelta(minutes=-10) + expired_token = create_access_token(data, expired_delta) + decoded = validate_token(expired_token, "access") + assert decoded is None + + +def test_invalid_token_type(): + data = {"sub": "test@example.com"} + access_token = create_access_token(data) + decoded = validate_token(access_token, "refresh") + assert decoded is None + +def test_password_reset_url_generation(): + """ + Tests that the password reset URL is correctly formatted and contains + the required query parameters. + """ + test_email = "test@example.com" + test_token = "abc123" + + url = generate_password_reset_url(test_email, test_token) + + # Parse the URL + parsed = urlparse(url) + query_params = parse_qs(parsed.query) + + # Get the actual path from the FastAPI app + reset_password_path: URLPath = app.url_path_for("reset_password") + + # Verify URL path + assert parsed.path == str(reset_password_path) + + # Verify query parameters + assert "email" in query_params + assert "token" in query_params + assert query_params["email"][0] == test_email + assert query_params["token"][0] == test_token + +def test_password_pattern(): + """ + Tests that the password pattern is correctly defined. to require at least + one uppercase letter, one lowercase letter, one digit, and one special + character, and at least 8 characters long. Allowed special characters are: + !@#$%^&*()_+-=[]{}|;:,.<>? + """ + special_characters = "!@#$%^&*()_+-=[]{}|;:,.<>?" + uppercase_letters = string.ascii_uppercase + lowercase_letters = string.ascii_lowercase + digits = string.digits + + required_elements = { + "special": special_characters, + "uppercase": uppercase_letters, + "lowercase": lowercase_letters, + "digit": digits + } + + # Valid password tests + for element in required_elements: + for c in required_elements[element]: + password = c + "test" + for other_element in required_elements: + if other_element != element: + password += random.choice(required_elements[other_element]) + # Randomize the order of the characters in the string + password = ''.join(random.sample(password, len(password))) + assert re.match(COMPILED_PASSWORD_PATTERN, password) is not None, f"Password {password} does not match the pattern" + + # Invalid password tests + + # Empty password + password = "" + assert re.match(COMPILED_PASSWORD_PATTERN, password) is None + + # Too short + password = "aA1!aA1" + assert re.match(COMPILED_PASSWORD_PATTERN, password) is None + + # No uppercase letter + password = "a1!" * 3 + assert re.match(COMPILED_PASSWORD_PATTERN, password) is None + + # No lowercase letter + password = "A1!" * 3 + assert re.match(COMPILED_PASSWORD_PATTERN, password) is None + + # No digit + password = "aA!" * 3 + assert re.match(COMPILED_PASSWORD_PATTERN, password) is None + + # No special character + password = "aA1" * 3 + assert re.match(COMPILED_PASSWORD_PATTERN, password) is None diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 3f76c79..4207186 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -12,11 +12,8 @@ from utils.models import User, PasswordResetToken from utils.auth import ( create_access_token, - create_refresh_token, verify_password, - get_password_hash, validate_token, - generate_password_reset_url ) from .conftest import SetupError @@ -41,49 +38,6 @@ def mock_resend_send(mock_email_response): yield mock -# --- Authentication Helper Function Tests --- - - -def test_password_hashing(): - password = "Test123!@#" - hashed = get_password_hash(password) - assert verify_password(password, hashed) - assert not verify_password("wrong_password", hashed) - - -def test_token_creation_and_validation(): - data = {"sub": "test@example.com"} - - # Test access token - access_token = create_access_token(data) - decoded = validate_token(access_token, "access") - assert decoded is not None - assert decoded["sub"] == data["sub"] - assert decoded["type"] == "access" - - # Test refresh token - refresh_token = create_refresh_token(data) - decoded = validate_token(refresh_token, "refresh") - assert decoded is not None - assert decoded["sub"] == data["sub"] - assert decoded["type"] == "refresh" - - -def test_expired_token(): - data = {"sub": "test@example.com"} - expired_delta = timedelta(minutes=-10) - expired_token = create_access_token(data, expired_delta) - decoded = validate_token(expired_token, "access") - assert decoded is None - - -def test_invalid_token_type(): - data = {"sub": "test@example.com"} - access_token = create_access_token(data) - decoded = validate_token(access_token, "refresh") - assert decoded is None - - # --- API Endpoint Tests --- @@ -272,33 +226,6 @@ def test_password_reset_with_invalid_token(unauth_client: TestClient, test_user: assert response.status_code == 400 -def test_password_reset_url_generation(unauth_client: TestClient): - """ - Tests that the password reset URL is correctly formatted and contains - the required query parameters. - """ - test_email = "test@example.com" - test_token = "abc123" - - url = generate_password_reset_url(test_email, test_token) - - # Parse the URL - parsed = urlparse(url) - query_params = parse_qs(parsed.query) - - # Get the actual path from the FastAPI app - reset_password_path: URLPath = app.url_path_for("reset_password") - - # Verify URL path - assert parsed.path == str(reset_password_path) - - # Verify query parameters - assert "email" in query_params - assert "token" in query_params - assert query_params["email"][0] == test_email - assert query_params["token"][0] == test_token - - def test_password_reset_email_url(unauth_client: TestClient, session: Session, test_user: User, mock_resend_send): """ Tests that the password reset email contains a properly formatted reset URL. diff --git a/utils/auth.py b/utils/auth.py index 01fc8f4..d809fc7 100644 --- a/utils/auth.py +++ b/utils/auth.py @@ -34,6 +34,54 @@ ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 30 REFRESH_TOKEN_EXPIRE_DAYS = 30 +PASSWORD_PATTERN_COMPONENTS = [ + r"(?=.*\d)", # At least one digit + r"(?=.*[a-z])", # At least one lowercase letter + r"(?=.*[A-Z])", # At least one uppercase letter + r"(?=.*[\[\]\\@$!%*?&{}<>.,'#\-_=+\(\):;|~/\^])", # At least one special character + r".{8,}" # At least 8 characters long +] +COMPILED_PASSWORD_PATTERN = re.compile(r"".join(PASSWORD_PATTERN_COMPONENTS)) + + +def convert_python_regex_to_html(regex: str) -> str: + """ + Replace each special character with its escaped version only when inside character classes. + Ensures that the single quote "'" is doubly escaped. + """ + # Map each special char to its escaped form + special_map = { + '{': r'\{', + '}': r'\}', + '<': r'\<', + '>': r'\>', + '.': r'\.', + '+': r'\+', + '|': r'\|', + ',': r'\,', + "'": r"\\'", # doubly escaped single quote + "/": r"\/", + } + + # Regex to match the entire character class [ ... ] + pattern = r"\[((?:\\.|[^\]])*)\]" + + def replacer(match: re.Match) -> str: + """ + For the matched character class, replace all special characters inside it. + """ + inside = match.group(1) # the contents inside [ ... ] + for ch, escaped in special_map.items(): + inside = inside.replace(ch, escaped) + return f"[{inside}]" + + # Use re.sub with a function to ensure we only replace inside the character class + return re.sub(pattern, replacer, regex) + + +HTML_PASSWORD_PATTERN = "".join( + convert_python_regex_to_html(component) for component in PASSWORD_PATTERN_COMPONENTS +) # --- Custom Exceptions --- @@ -105,11 +153,8 @@ def validate_password_strength(v: str) -> str: - At least 8 characters long """ logger.debug(f"Validating password for {field_name}") - pattern = re.compile( - r"(?=.*\d)(?=.*[a-z])(?=.*[A-Z])(?=.*[@$!%*?&{}<>.,\\'#\-_=+\(\)\[\]:;|~/])[A-Za-z\d@$!%*?&{}<>.,\\'#\-_=+\(\)\[\]:;|~/]{8,}") - if not pattern.match(v): - logger.debug(f"Password for { - field_name} does not satisfy the security policy") + if not COMPILED_PASSWORD_PATTERN.match(v): + logger.debug(f"Password for {field_name} does not satisfy the security policy") raise PasswordValidationError( field=field_name, message=f"{field_name} does not satisfy the security policy"