diff --git a/courses/views/v1/views_test.py b/courses/views/v1/views_test.py index 7b84128458..3be4e508b6 100644 --- a/courses/views/v1/views_test.py +++ b/courses/views/v1/views_test.py @@ -4,16 +4,20 @@ # pylint: disable=unused-argument, redefined-outer-name, too-many-arguments import operator as op +from datetime import datetime, timezone +from unittest.mock import MagicMock, patch from urllib.parse import quote import pytest import reversion from django.db.models import Count, Q +from django.test import RequestFactory from django.test.client import Client from django.urls import reverse from requests import ConnectionError as RequestsConnectionError from requests import HTTPError from rest_framework import status +from rest_framework.test import APIClient from reversion.models import Version from courses.constants import ENROLL_CHANGE_STATUS_UNENROLLED @@ -24,7 +28,9 @@ CourseRunFactory, ) from courses.models import ( + Course, CourseRun, + Program, ProgramEnrollment, ) from courses.serializers.v1.courses import ( @@ -38,7 +44,11 @@ num_queries_from_course, num_queries_from_programs, ) -from courses.views.v1 import UserEnrollmentsApiViewSet +from courses.views.v1 import ( + CourseFilterSet, + ProgramViewSet, + UserEnrollmentsApiViewSet, +) from ecommerce.factories import LineFactory, OrderFactory, ProductFactory from ecommerce.models import Order, OrderStatus from main import features @@ -51,6 +61,7 @@ from main.test_utils import assert_drf_json_equal, duplicate_queries_check from main.utils import encode_json_cookie_value from openedx.exceptions import NoEdxApiAuthError +from users.factories import UserFactory pytestmark = [pytest.mark.django_db] @@ -808,3 +819,159 @@ def test_create_enrollments_with_existing_fulfilled_order( else: assert Order.objects.filter(state=OrderStatus.PENDING).count() == 1 patched_create_enrollments.assert_called_once() + + +@pytest.mark.django_db +class TestCourseFilterSet: + """Test CourseFilterSet filtering methods""" + + def test_filter_courserun_is_enrollable_true(self): + """Test filtering for enrollable courses""" + # Create courses with different enrollment status + enrollable_course = CourseFactory.create(live=True) + # Create an enrollable course run: live=True, has enrollment_start in past, no enrollment_end, has start_date + CourseRunFactory.create( + course=enrollable_course, + live=True, + enrollment_start=datetime(2020, 1, 1, tzinfo=timezone.utc), + enrollment_end=None, + start_date=datetime(2020, 1, 15, tzinfo=timezone.utc), + ) + + non_enrollable_course = CourseFactory.create(live=True) + CourseRunFactory.create(course=non_enrollable_course, live=False) + + queryset = Course.objects.all() + filterset = CourseFilterSet() + + # Test filtering for enrollable courses + result = filterset.filter_courserun_is_enrollable(queryset, None, value=True) + assert enrollable_course in result + + def test_filter_courserun_is_enrollable_false(self): + """Test filtering for non-enrollable courses""" + # Create courses with different enrollment status + enrollable_course = CourseFactory.create(live=True) + # Create an enrollable course run: live=True, has enrollment_start in past, no enrollment_end, has start_date + CourseRunFactory.create( + course=enrollable_course, + live=True, + enrollment_start=datetime(2020, 1, 1, tzinfo=timezone.utc), + enrollment_end=None, + start_date=datetime(2020, 1, 15, tzinfo=timezone.utc), + ) + + non_enrollable_course = CourseFactory.create(live=True) + CourseRunFactory.create(course=non_enrollable_course, live=False) + + queryset = Course.objects.all() + filterset = CourseFilterSet() + + # Test filtering for non-enrollable courses + result = filterset.filter_courserun_is_enrollable(queryset, None, value=False) + assert non_enrollable_course in result + + +@pytest.mark.django_db +class TestProgramViewSetPagination: + """Test ProgramViewSet pagination edge cases""" + + def setup_method(self): + self.factory = RequestFactory() + self.user = UserFactory.create() + self.viewset = ProgramViewSet() + + def test_paginate_queryset_no_page_param(self): + """Test pagination when no page parameter is provided""" + # Create a mock request without page parameter + request = self.factory.get("/api/v1/programs/") + request.user = self.user + request.query_params = {} + + # Set up the viewset + self.viewset.request = request + + # Create a mock pagination class instance + mock_pagination_class = MagicMock() + # Mock the pagination_class attribute to return our mock + with patch.object(self.viewset, "pagination_class", mock_pagination_class): + queryset = Program.objects.all() + result = self.viewset.paginate_queryset(queryset) + + # Should return None when no page param and paginator exists + assert result is None + + def test_paginate_queryset_with_page_param(self): + """Test pagination when page parameter is provided""" + request = self.factory.get("/api/v1/programs/?page=1") + request.user = self.user + request.query_params = {"page": "1"} + + self.viewset.request = request + + queryset = Program.objects.all() + + # Mock both pagination_class and the parent's paginate_queryset method + mock_pagination_class = MagicMock() + with ( + patch.object(self.viewset, "pagination_class", mock_pagination_class), + patch( + "rest_framework.viewsets.ReadOnlyModelViewSet.paginate_queryset" + ) as mock_super, + ): + mock_super.return_value = "paginated_result" + result = self.viewset.paginate_queryset(queryset) + + mock_super.assert_called_once_with(queryset) + assert result == "paginated_result" + + +@pytest.mark.django_db +class TestUserEnrollmentsApiViewSetSync: + """Test UserEnrollmentsApiViewSet sync functionality""" + + def setup_method(self): + self.factory = RequestFactory() + self.user = UserFactory.create() + self.client = APIClient() + self.client.force_authenticate(user=self.user) + self.viewset = UserEnrollmentsApiViewSet() + + @patch("courses.views.v1.sync_enrollments_with_edx") + @patch("courses.views.v1.is_enabled") + def test_list_with_sync_enabled_success(self, mock_is_enabled, mock_sync): + """Test list method when sync is enabled and succeeds""" + mock_is_enabled.return_value = True + mock_sync.return_value = None + + response = self.client.get(reverse("v1:user-enrollments-api-list")) + + assert response.status_code == status.HTTP_200_OK + mock_sync.assert_called_once_with(self.user) + + @patch("courses.views.v1.sync_enrollments_with_edx") + @patch("courses.views.v1.is_enabled") + @patch("courses.views.v1.log.exception") + def test_list_with_sync_enabled_exception( + self, mock_log, mock_is_enabled, mock_sync + ): + """Test list method when sync is enabled but fails""" + mock_is_enabled.return_value = True + mock_sync.side_effect = Exception("Sync failed") + + response = self.client.get(reverse("v1:user-enrollments-api-list")) + + assert response.status_code == status.HTTP_200_OK + mock_sync.assert_called_once_with(self.user) + mock_log.assert_called_once_with("Failed to sync user enrollments with edX") + + @patch("courses.views.v1.sync_enrollments_with_edx") + @patch("courses.views.v1.is_enabled") + def test_list_with_sync_disabled(self, mock_is_enabled, mock_sync): + """Test list method when sync is disabled""" + mock_is_enabled.return_value = False + + response = self.client.get(reverse("v1:user-enrollments-api-list")) + + assert response.status_code == status.HTTP_200_OK + mock_sync.assert_not_called()