Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 39 additions & 75 deletions todo/tests/unit/views/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from rest_framework.test import APISimpleTestCase, APIClient, APIRequestFactory
from rest_framework.test import APITestCase, APIClient, APIRequestFactory
from rest_framework.reverse import reverse
from rest_framework import status
from unittest.mock import patch, Mock, PropertyMock
Expand All @@ -14,7 +14,7 @@
from todo.constants.messages import AppMessages, AuthErrorMessages


class GoogleLoginViewTests(APISimpleTestCase):
class GoogleLoginViewTests(APITestCase):
def setUp(self):
super().setUp()
self.client = APIClient()
Expand Down Expand Up @@ -59,46 +59,54 @@ def test_get_with_redirect_url(self, mock_get_auth_url):
mock_get_auth_url.assert_called_once_with(redirect_url)


class GoogleCallbackViewTests(APISimpleTestCase):
class GoogleCallbackViewTests(APITestCase):
def setUp(self):
super().setUp()
self.client = APIClient()
self.url = reverse("google_callback")
self.factory = APIRequestFactory()
self.view = GoogleCallbackView.as_view()

def test_get_returns_error_for_oauth_error(self):
def test_get_redirects_for_oauth_error(self):
error = "access_denied"
request = self.factory.get(f"{self.url}?error={error}")
response = self.client.get(f"{self.url}?error={error}")

response = self.view(request)
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
self.assertIn("error=access_denied", response.url)

self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data["message"], error)
self.assertEqual(response.data["errors"][0]["detail"], error)
def test_get_redirects_for_missing_code(self):
response = self.client.get(self.url)

self.assertEqual(response.status_code, status.HTTP_302_FOUND)
self.assertIn("error=missing_parameters", response.url)

def test_get_redirects_for_valid_code_and_state(self):
response = self.client.get(f"{self.url}?code=test_code&state=test_state")

def test_get_returns_error_for_missing_code(self):
request = self.factory.get(self.url)
self.assertEqual(response.status_code, status.HTTP_302_FOUND)
self.assertIn("code=test_code", response.url)
self.assertIn("state=test_state", response.url)

response = self.view(request)
def test_post_returns_error_for_missing_code(self):
response = self.client.post(self.url, {})

self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data["message"], "No authorization code received from Google")
self.assertEqual(response.data["errors"][0]["detail"], "No authorization code received from Google")

def test_get_returns_error_for_invalid_state(self):
request = self.factory.get(f"{self.url}?code=test_code&state=invalid_state")
request.session = {"oauth_state": "different_state"}
def test_post_returns_error_for_invalid_state(self):

session = self.client.session
session["oauth_state"] = "different_state"
session.save()

response = self.view(request)
response = self.client.post(self.url, {"code": "test_code", "state": "invalid_state"})

self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data["message"], "Invalid state parameter")
self.assertEqual(response.data["errors"][0]["detail"], "Invalid state parameter")

@patch("todo.services.google_oauth_service.GoogleOAuthService.handle_callback")
@patch("todo.services.user_service.UserService.create_or_update_user")
def test_get_handles_callback_successfully(self, mock_create_user, mock_handle_callback):
def test_post_handles_callback_successfully(self, mock_create_user, mock_handle_callback):
mock_google_data = {
"id": "test_google_id",
"email": "[email protected]",
Expand All @@ -115,70 +123,26 @@ def test_get_handles_callback_successfully(self, mock_create_user, mock_handle_c
mock_handle_callback.return_value = mock_google_data
mock_create_user.return_value = mock_user

request = self.factory.get(f"{self.url}?code=test_code&state=test_state")
request.session = {"oauth_state": "test_state"}
session = self.client.session
session["oauth_state"] = "test_state"
session.save()

response = self.view(request)
response = self.client.post(self.url, {"code": "test_code", "state": "test_state"})

self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("✅ Google OAuth Login Successful!", response.content.decode())
self.assertIn(str(mock_user.id), response.content.decode())
self.assertIn(mock_user.name, response.content.decode())
self.assertIn(mock_user.email_id, response.content.decode())
self.assertIn(mock_user.google_id, response.content.decode())
self.assertEqual(response.data["data"]["user"]["id"], user_id)
self.assertEqual(response.data["data"]["user"]["name"], mock_user.name)
self.assertEqual(response.data["data"]["user"]["email"], mock_user.email_id)
self.assertEqual(response.data["data"]["user"]["google_id"], mock_user.google_id)
self.assertIn("ext-access", response.cookies)
self.assertIn("ext-refresh", response.cookies)
self.assertNotIn("oauth_state", request.session)
self.assertNotIn("oauth_state", self.client.session)


class GoogleAuthStatusViewTests(APISimpleTestCase):
def setUp(self):
super().setUp()
self.client = APIClient()
self.url = reverse("google_status")

def test_get_returns_401_when_no_access_token(self):
response = self.client.get(self.url)

self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.data["message"], AuthErrorMessages.NO_ACCESS_TOKEN)
self.assertEqual(response.data["authenticated"], False)
self.assertEqual(response.data["statusCode"], status.HTTP_401_UNAUTHORIZED)

@patch("todo.utils.google_jwt_utils.validate_google_access_token")
@patch("todo.services.user_service.UserService.get_user_by_id")
def test_get_returns_user_info_when_authenticated(self, mock_get_user, mock_validate_token):
user_id = str(ObjectId())
user_data = {
"user_id": user_id,
"google_id": "test_google_id",
"email": "[email protected]",
"name": "Test User",
}
mock_validate_token.return_value = user_data

mock_user = Mock()
mock_user.id = ObjectId(user_id)
mock_user.google_id = "test_google_id"
mock_user.email_id = "[email protected]"
mock_user.name = "Test User"
type(mock_user).id = PropertyMock(return_value=ObjectId(user_id))

mock_get_user.return_value = mock_user

tokens = generate_google_token_pair(user_data)
self.client.cookies["ext-access"] = tokens["access_token"]

response = self.client.get(self.url, HTTP_ACCEPT="application/json")

self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["data"]["user"]["id"], user_id)
self.assertEqual(response.data["data"]["user"]["email"], mock_user.email_id)
self.assertEqual(response.data["data"]["user"]["name"], mock_user.name)
self.assertEqual(response.data["data"]["user"]["google_id"], mock_user.google_id)


class GoogleRefreshViewTests(APISimpleTestCase):
class GoogleRefreshViewTests(APITestCase):
def setUp(self):
super().setUp()
self.client = APIClient()
Expand Down Expand Up @@ -213,7 +177,7 @@ def test_get_refreshes_token_successfully(self, mock_validate_token):
self.assertIn("ext-access", response.cookies)


class GoogleLogoutViewTests(APISimpleTestCase):
class GoogleLogoutViewTests(APITestCase):
def setUp(self):
super().setUp()
self.client = APIClient()
Expand All @@ -231,7 +195,7 @@ def test_get_returns_success_and_clears_cookies(self):
self.client.cookies["ext-refresh"] = tokens["refresh_token"]

response = self.client.get(self.url, HTTP_ACCEPT="application/json")

self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data["data"]["success"], True)
self.assertEqual(response.data["message"], AppMessages.GOOGLE_LOGOUT_SUCCESS)
Expand Down
2 changes: 0 additions & 2 deletions todo/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
GoogleCallbackView,
GoogleRefreshView,
GoogleLogoutView,
GoogleAuthStatusView,
)

urlpatterns = [
Expand All @@ -15,7 +14,6 @@
path("health", HealthView.as_view(), name="health"),
path("auth/google/login/", GoogleLoginView.as_view(), name="google_login"),
path("auth/google/callback/", GoogleCallbackView.as_view(), name="google_callback"),
path("auth/google/status/", GoogleAuthStatusView.as_view(), name="google_status"),
path("auth/google/refresh/", GoogleRefreshView.as_view(), name="google_refresh"),
path("auth/google/logout/", GoogleLogoutView.as_view(), name="google_logout"),
]
Loading