Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@ jobs:
runs-on: ubuntu-latest
if: ${{ !contains(github.event.pull_request.title, '[skip tests]') }}

env:
MONGODB_URI: mongodb://db:27017
DB_NAME: todo-app
GOOGLE_JWT_SECRET_KEY: "test-secret-key-for-jwt"
GOOGLE_JWT_ACCESS_LIFETIME: "3600"
GOOGLE_JWT_REFRESH_LIFETIME: "604800"
GOOGLE_OAUTH_CLIENT_ID: "test-client-id"
GOOGLE_OAUTH_CLIENT_SECRET: "test-client-secret"
GOOGLE_OAUTH_REDIRECT_URI: "http://localhost:3000/auth/callback"
COOKIE_SECURE: "False"
COOKIE_SAMESITE: "Lax"

steps:
- name: Checkout code
uses: actions/checkout@v3
Expand All @@ -17,6 +29,7 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: "3.11.*"
python-version: "3.11.*"

- name: Install dependencies
run: |
Expand Down
46 changes: 26 additions & 20 deletions todo/middlewares/jwt_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,16 @@ def __call__(self, request):
error_response = ApiErrorResponse(
statusCode=status.HTTP_401_UNAUTHORIZED,
message=AuthErrorMessages.AUTHENTICATION_REQUIRED,
errors=[ApiErrorDetail(
title=ApiErrors.AUTHENTICATION_FAILED.format(""),
detail=AuthErrorMessages.AUTHENTICATION_REQUIRED
)],
errors=[
ApiErrorDetail(
title=ApiErrors.AUTHENTICATION_FAILED.format(""),
detail=AuthErrorMessages.AUTHENTICATION_REQUIRED,
)
],
)
return JsonResponse(
data=error_response.model_dump(mode="json", exclude_none=True), status=status.HTTP_401_UNAUTHORIZED
)
return JsonResponse(data=error_response.model_dump(mode="json", exclude_none=True), status=status.HTTP_401_UNAUTHORIZED)

except (TokenMissingError, TokenExpiredError, TokenInvalidError) as e:
return self._handle_rds_auth_error(e)
Expand All @@ -45,12 +49,16 @@ def __call__(self, request):
error_response = ApiErrorResponse(
statusCode=status.HTTP_401_UNAUTHORIZED,
message=ApiErrors.AUTHENTICATION_FAILED.format(""),
errors=[ApiErrorDetail(
title=ApiErrors.AUTHENTICATION_FAILED.format(""),
detail=AuthErrorMessages.AUTHENTICATION_REQUIRED
)],
errors=[
ApiErrorDetail(
title=ApiErrors.AUTHENTICATION_FAILED.format(""),
detail=AuthErrorMessages.AUTHENTICATION_REQUIRED,
)
],
)
return JsonResponse(
data=error_response.model_dump(mode="json", exclude_none=True), status=status.HTTP_401_UNAUTHORIZED
)
return JsonResponse(data=error_response.model_dump(mode="json", exclude_none=True), status=status.HTTP_401_UNAUTHORIZED)

def _try_authentication(self, request) -> bool:
if self._try_google_auth(request):
Expand Down Expand Up @@ -111,23 +119,21 @@ def _handle_rds_auth_error(self, exception):
error_response = ApiErrorResponse(
statusCode=status.HTTP_401_UNAUTHORIZED,
message=str(exception),
errors=[ApiErrorDetail(
title=ApiErrors.AUTHENTICATION_FAILED.format(""),
detail=str(exception)
)],
errors=[ApiErrorDetail(title=ApiErrors.AUTHENTICATION_FAILED.format(""), detail=str(exception))],
)
return JsonResponse(
data=error_response.model_dump(mode="json", exclude_none=True), status=status.HTTP_401_UNAUTHORIZED
)
return JsonResponse(data=error_response.model_dump(mode="json", exclude_none=True), status=status.HTTP_401_UNAUTHORIZED)

def _handle_google_auth_error(self, exception):
error_response = ApiErrorResponse(
statusCode=status.HTTP_401_UNAUTHORIZED,
message=str(exception),
errors=[ApiErrorDetail(
title=ApiErrors.AUTHENTICATION_FAILED.format(""),
detail=str(exception)
)],
errors=[ApiErrorDetail(title=ApiErrors.AUTHENTICATION_FAILED.format(""), detail=str(exception))],
)
return JsonResponse(
data=error_response.model_dump(mode="json", exclude_none=True), status=status.HTTP_401_UNAUTHORIZED
)
return JsonResponse(data=error_response.model_dump(mode="json", exclude_none=True), status=status.HTTP_401_UNAUTHORIZED)


def is_google_user(request) -> bool:
Expand Down
18 changes: 18 additions & 0 deletions todo/tests/fixtures/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from datetime import datetime, timezone

users_db_data = [
{
"google_id": "123456789",
"email_id": "[email protected]",
"name": "Test User",
"created_at": datetime.now(timezone.utc),
"updated_at": datetime.now(timezone.utc),
},
{
"google_id": "987654321",
"email_id": "[email protected]",
"name": "Another User",
"created_at": datetime.now(timezone.utc),
"updated_at": datetime.now(timezone.utc),
},
]
25 changes: 21 additions & 4 deletions todo/tests/integration/test_task_detail_api.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
from http import HTTPStatus
from bson import ObjectId
from django.urls import reverse
from rest_framework.test import APIClient
from bson import ObjectId

from todo.tests.fixtures.task import tasks_db_data
from todo.tests.integration.base_mongo_test import BaseMongoTestCase
from todo.constants.messages import ApiErrors, ValidationErrors
from todo.utils.google_jwt_utils import generate_google_token_pair


class AuthenticatedMongoTestCase(BaseMongoTestCase):
def setUp(self):
super().setUp()
self._setup_auth_cookies()

def _setup_auth_cookies(self):
user_data = {
"user_id": str(ObjectId()),
"google_id": "test_google_id",
"email": "[email protected]",
"name": "Test User",
}
tokens = generate_google_token_pair(user_data)
self.client.cookies["ext-access"] = tokens["access_token"]
self.client.cookies["ext-refresh"] = tokens["refresh_token"]


class TaskDetailAPIIntegrationTest(BaseMongoTestCase):
class TaskDetailAPIIntegrationTest(AuthenticatedMongoTestCase):
def setUp(self):
super().setUp()
self.db.tasks.delete_many({}) # Clear tasks to avoid DuplicateKeyError
Expand All @@ -17,7 +35,6 @@ def setUp(self):
self.existing_task_id = str(self.task_doc["_id"])
self.non_existent_id = str(ObjectId())
self.invalid_task_id = "invalid-task-id"
self.client = APIClient()

def test_get_task_by_id_success(self):
url = reverse("task_detail", args=[self.existing_task_id])
Expand Down
25 changes: 21 additions & 4 deletions todo/tests/integration/test_tasks_delete.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
from http import HTTPStatus
from bson import ObjectId
from django.urls import reverse
from rest_framework.test import APIClient
from bson import ObjectId

from todo.tests.fixtures.task import tasks_db_data
from todo.tests.integration.base_mongo_test import BaseMongoTestCase
from todo.constants.messages import ValidationErrors, ApiErrors
from todo.utils.google_jwt_utils import generate_google_token_pair


class AuthenticatedMongoTestCase(BaseMongoTestCase):
def setUp(self):
super().setUp()
self._setup_auth_cookies()

def _setup_auth_cookies(self):
user_data = {
"user_id": str(ObjectId()),
"google_id": "test_google_id",
"email": "[email protected]",
"name": "Test User",
}
tokens = generate_google_token_pair(user_data)
self.client.cookies["ext-access"] = tokens["access_token"]
self.client.cookies["ext-refresh"] = tokens["refresh_token"]


class TaskDeleteAPIIntegrationTest(BaseMongoTestCase):
class TaskDeleteAPIIntegrationTest(AuthenticatedMongoTestCase):
def setUp(self):
super().setUp()
self.db.tasks.delete_many({})
Expand All @@ -17,7 +35,6 @@ def setUp(self):
self.existing_task_id = str(task_doc["_id"])
self.non_existent_id = str(ObjectId())
self.invalid_task_id = "invalid-task-id"
self.client = APIClient()

def test_delete_task_success(self):
url = reverse("task_detail", args=[self.existing_task_id])
Expand Down
39 changes: 20 additions & 19 deletions todo/tests/unit/exceptions/test_exception_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,24 @@
class ExceptionHandlerTests(TestCase):
@patch("todo.exceptions.exception_handler.format_validation_errors")
def test_returns_400_for_validation_error(self, mock_format_validation_errors: Mock):
validation_error = DRFValidationError(detail={"field": ["error message"]})
mock_format_validation_errors.return_value = [
ApiErrorDetail(detail="error message", source={ApiErrorSource.PARAMETER: "field"})
]

response = handle_exception(validation_error, {})

self.assertIsInstance(response, Response)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
expected_response = {
"statusCode": 400,
"message": "Invalid request",
"errors": [{"source": {"parameter": "field"}, "detail": "error message"}],
}
self.assertDictEqual(response.data, expected_response)

mock_format_validation_errors.assert_called_once_with(validation_error.detail)
error_detail = {"field": ["error message"]}
exception = DRFValidationError(detail=error_detail)
request = Mock()

with patch("todo.exceptions.exception_handler.format_validation_errors") as mock_format:
mock_format.return_value = [
ApiErrorDetail(detail="error message", source={ApiErrorSource.PARAMETER: "field"})
]
response = handle_exception(exception, {"request": request})

self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
expected_response = {
"statusCode": 400,
"message": "error message",
"errors": [{"source": {"parameter": "field"}, "detail": "error message"}],
}
self.assertDictEqual(response.data, expected_response)
mock_format.assert_called_once_with(error_detail)

def test_custom_handler_formats_generic_exception(self):
request = None
Expand All @@ -51,9 +52,9 @@ def test_custom_handler_formats_generic_exception(self):

expected_detail_obj_in_list = ApiErrorDetail(
detail=error_message if settings.DEBUG else ApiErrors.INTERNAL_SERVER_ERROR,
title=ApiErrors.UNEXPECTED_ERROR_OCCURRED,
title=error_message,
)
expected_main_message = ApiErrors.UNEXPECTED_ERROR_OCCURRED
expected_main_message = ApiErrors.INTERNAL_SERVER_ERROR

self.assertEqual(response.data.get("statusCode"), status.HTTP_500_INTERNAL_SERVER_ERROR)
self.assertEqual(response.data.get("message"), expected_main_message)
Expand Down
1 change: 1 addition & 0 deletions todo/tests/unit/middlewares/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# This file is required for Python to recognize this directory as a package
130 changes: 130 additions & 0 deletions todo/tests/unit/middlewares/test_jwt_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from unittest import TestCase
from unittest.mock import Mock, patch
from django.http import HttpRequest, JsonResponse
from django.conf import settings
from rest_framework import status
import json

from todo.middlewares.jwt_auth import JWTAuthenticationMiddleware, is_google_user, is_rds_user, get_current_user_info
from todo.constants.messages import AuthErrorMessages


class JWTAuthenticationMiddlewareTests(TestCase):
def setUp(self):
self.get_response = Mock(return_value=JsonResponse({"data": "test"}))
self.middleware = JWTAuthenticationMiddleware(self.get_response)
self.request = Mock(spec=HttpRequest)
self.request.path = "/v1/tasks"
self.request.headers = {}
self.request.COOKIES = {}
settings.PUBLIC_PATHS = ["/v1/auth/google/login"]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works, but the best practice is to use Django's reverse() utility. Hardcoding URL paths makes our tests brittle. If we ever decide to refactor our API routes (e.g., change /v1/ to /v2/), we would have to manually find and replace these strings in every single test file, which is error-prone.


def test_public_path_authentication_bypass(self):
"""Test that requests to public paths bypass authentication"""
self.request.path = "/v1/auth/google/login"
response = self.middleware(self.request)
self.get_response.assert_called_once_with(self.request)
self.assertEqual(response.status_code, 200)

@patch("todo.middlewares.jwt_auth.JWTAuthenticationMiddleware._try_google_auth")
def test_google_auth_success(self, mock_google_auth):
"""Test successful Google authentication"""
mock_google_auth.return_value = True
self.request.COOKIES = {"ext-access": "google_token"}
response = self.middleware(self.request)
mock_google_auth.assert_called_once_with(self.request)
self.get_response.assert_called_once_with(self.request)
self.assertEqual(response.status_code, 200)

@patch("todo.middlewares.jwt_auth.JWTAuthenticationMiddleware._try_rds_auth")
def test_rds_auth_success(self, mock_rds_auth):
"""Test successful RDS authentication"""
mock_rds_auth.return_value = True
self.request.COOKIES = {"rds_session_v2": "valid_token"}
response = self.middleware(self.request)
mock_rds_auth.assert_called_once_with(self.request)
self.get_response.assert_called_once_with(self.request)
self.assertEqual(response.status_code, 200)

@patch("todo.middlewares.jwt_auth.JWTAuthenticationMiddleware._try_google_auth")
def test_google_token_expired(self, mock_google_auth):
"""Test handling of expired Google token"""
mock_google_auth.return_value = False
self.request.COOKIES = {"ext-access": "expired_token"}
response = self.middleware(self.request)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
response_data = json.loads(response.content)
self.assertEqual(response_data["message"], AuthErrorMessages.AUTHENTICATION_REQUIRED)

@patch("todo.middlewares.jwt_auth.JWTAuthenticationMiddleware._try_rds_auth")
def test_rds_token_invalid(self, mock_rds_auth):
"""Test handling of invalid RDS token"""
mock_rds_auth.return_value = False
self.request.COOKIES = {"rds_session_v2": "invalid_token"}
response = self.middleware(self.request)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
response_data = json.loads(response.content)
self.assertEqual(response_data["message"], AuthErrorMessages.AUTHENTICATION_REQUIRED)

def test_no_tokens_provided(self):
"""Test handling of request with no tokens"""
response = self.middleware(self.request)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
response_data = json.loads(response.content)
self.assertEqual(response_data["message"], AuthErrorMessages.AUTHENTICATION_REQUIRED)


class AuthUtilityFunctionsTests(TestCase):
def setUp(self):
self.request = Mock(spec=HttpRequest)

def test_is_google_user(self):
"""Test checking if request is from Google user"""
self.request.auth_type = "google"
self.assertTrue(is_google_user(self.request))

self.request.auth_type = None
self.assertFalse(is_google_user(self.request))

self.request.auth_type = "rds"
self.assertFalse(is_google_user(self.request))

def test_is_rds_user(self):
"""Test checking if request is from RDS user"""
self.request.auth_type = "rds"
self.assertTrue(is_rds_user(self.request))

self.request.auth_type = None
self.assertFalse(is_rds_user(self.request))

self.request.auth_type = "google"
self.assertFalse(is_rds_user(self.request))

def test_get_current_user_info_google(self):
"""Test getting user info for Google user"""
self.request.user_id = "google_user_123"
self.request.auth_type = "google"
self.request.google_id = "google_123"
self.request.user_email = "[email protected]"
self.request.user_name = "Test User"
user_info = get_current_user_info(self.request)
self.assertEqual(user_info["user_id"], "google_user_123")
self.assertEqual(user_info["auth_type"], "google")
self.assertEqual(user_info["google_id"], "google_123")
self.assertEqual(user_info["email"], "[email protected]")
self.assertEqual(user_info["name"], "Test User")

def test_get_current_user_info_rds(self):
"""Test getting user info for RDS user"""
self.request.user_id = "rds_user_123"
self.request.auth_type = "rds"
self.request.user_role = "admin"
user_info = get_current_user_info(self.request)
self.assertEqual(user_info["user_id"], "rds_user_123")
self.assertEqual(user_info["auth_type"], "rds")
self.assertEqual(user_info["role"], "admin")

def test_get_current_user_info_no_user_id(self):
"""Test getting user info when no user ID is present"""
user_info = get_current_user_info(self.request)
self.assertIsNone(user_info)
Loading