Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 168 additions & 1 deletion courses/views/v1/views_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,20 @@

# pylint: disable=unused-argument, redefined-outer-name, too-many-arguments
import operator as op
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from urllib.parse import quote

import pytest
import reversion
from django.db.models import Count, Q
from django.test import RequestFactory
from django.test.client import Client
from django.urls import reverse
from requests import ConnectionError as RequestsConnectionError
from requests import HTTPError
from rest_framework import status
from rest_framework.test import APIClient
from reversion.models import Version

from courses.constants import ENROLL_CHANGE_STATUS_UNENROLLED
Expand All @@ -24,7 +28,9 @@
CourseRunFactory,
)
from courses.models import (
Course,
CourseRun,
Program,
ProgramEnrollment,
)
from courses.serializers.v1.courses import (
Expand All @@ -38,7 +44,11 @@
num_queries_from_course,
num_queries_from_programs,
)
from courses.views.v1 import UserEnrollmentsApiViewSet
from courses.views.v1 import (
CourseFilterSet,
ProgramViewSet,
UserEnrollmentsApiViewSet,
)
from ecommerce.factories import LineFactory, OrderFactory, ProductFactory
from ecommerce.models import Order, OrderStatus
from main import features
Expand All @@ -51,6 +61,7 @@
from main.test_utils import assert_drf_json_equal, duplicate_queries_check
from main.utils import encode_json_cookie_value
from openedx.exceptions import NoEdxApiAuthError
from users.factories import UserFactory

pytestmark = [pytest.mark.django_db]

Expand Down Expand Up @@ -808,3 +819,159 @@ def test_create_enrollments_with_existing_fulfilled_order(
else:
assert Order.objects.filter(state=OrderStatus.PENDING).count() == 1
patched_create_enrollments.assert_called_once()


@pytest.mark.django_db
class TestCourseFilterSet:
"""Test CourseFilterSet filtering methods"""

def test_filter_courserun_is_enrollable_true(self):
"""Test filtering for enrollable courses"""
# Create courses with different enrollment status
enrollable_course = CourseFactory.create(live=True)
# Create an enrollable course run: live=True, has enrollment_start in past, no enrollment_end, has start_date
CourseRunFactory.create(
course=enrollable_course,
live=True,
enrollment_start=datetime(2020, 1, 1, tzinfo=timezone.utc),
enrollment_end=None,
start_date=datetime(2020, 1, 15, tzinfo=timezone.utc),
)

non_enrollable_course = CourseFactory.create(live=True)
CourseRunFactory.create(course=non_enrollable_course, live=False)

queryset = Course.objects.all()
filterset = CourseFilterSet()

# Test filtering for enrollable courses
result = filterset.filter_courserun_is_enrollable(queryset, None, value=True)
assert enrollable_course in result

def test_filter_courserun_is_enrollable_false(self):
"""Test filtering for non-enrollable courses"""
# Create courses with different enrollment status
enrollable_course = CourseFactory.create(live=True)
# Create an enrollable course run: live=True, has enrollment_start in past, no enrollment_end, has start_date
CourseRunFactory.create(
course=enrollable_course,
live=True,
enrollment_start=datetime(2020, 1, 1, tzinfo=timezone.utc),
enrollment_end=None,
start_date=datetime(2020, 1, 15, tzinfo=timezone.utc),
)

non_enrollable_course = CourseFactory.create(live=True)
CourseRunFactory.create(course=non_enrollable_course, live=False)

queryset = Course.objects.all()
filterset = CourseFilterSet()

# Test filtering for non-enrollable courses
result = filterset.filter_courserun_is_enrollable(queryset, None, value=False)
assert non_enrollable_course in result


@pytest.mark.django_db
class TestProgramViewSetPagination:
"""Test ProgramViewSet pagination edge cases"""

def setup_method(self):
self.factory = RequestFactory()
self.user = UserFactory.create()
self.viewset = ProgramViewSet()

def test_paginate_queryset_no_page_param(self):
"""Test pagination when no page parameter is provided"""
# Create a mock request without page parameter
request = self.factory.get("/api/v1/programs/")
request.user = self.user
request.query_params = {}

# Set up the viewset
self.viewset.request = request

# Create a mock pagination class instance
mock_pagination_class = MagicMock()
# Mock the pagination_class attribute to return our mock
with patch.object(self.viewset, "pagination_class", mock_pagination_class):
queryset = Program.objects.all()
result = self.viewset.paginate_queryset(queryset)

# Should return None when no page param and paginator exists
assert result is None

def test_paginate_queryset_with_page_param(self):
"""Test pagination when page parameter is provided"""
request = self.factory.get("/api/v1/programs/?page=1")
request.user = self.user
request.query_params = {"page": "1"}

self.viewset.request = request

queryset = Program.objects.all()

# Mock both pagination_class and the parent's paginate_queryset method
mock_pagination_class = MagicMock()
with (
patch.object(self.viewset, "pagination_class", mock_pagination_class),
patch(
"rest_framework.viewsets.ReadOnlyModelViewSet.paginate_queryset"
) as mock_super,
):
mock_super.return_value = "paginated_result"
result = self.viewset.paginate_queryset(queryset)

mock_super.assert_called_once_with(queryset)
assert result == "paginated_result"


@pytest.mark.django_db
class TestUserEnrollmentsApiViewSetSync:
"""Test UserEnrollmentsApiViewSet sync functionality"""

def setup_method(self):
self.factory = RequestFactory()
self.user = UserFactory.create()
self.client = APIClient()
self.client.force_authenticate(user=self.user)
self.viewset = UserEnrollmentsApiViewSet()

@patch("courses.views.v1.sync_enrollments_with_edx")
@patch("courses.views.v1.is_enabled")
def test_list_with_sync_enabled_success(self, mock_is_enabled, mock_sync):
"""Test list method when sync is enabled and succeeds"""
mock_is_enabled.return_value = True
mock_sync.return_value = None

response = self.client.get(reverse("v1:user-enrollments-api-list"))

assert response.status_code == status.HTTP_200_OK
mock_sync.assert_called_once_with(self.user)

@patch("courses.views.v1.sync_enrollments_with_edx")
@patch("courses.views.v1.is_enabled")
@patch("courses.views.v1.log.exception")
def test_list_with_sync_enabled_exception(
self, mock_log, mock_is_enabled, mock_sync
):
"""Test list method when sync is enabled but fails"""
mock_is_enabled.return_value = True
mock_sync.side_effect = Exception("Sync failed")

response = self.client.get(reverse("v1:user-enrollments-api-list"))

assert response.status_code == status.HTTP_200_OK
mock_sync.assert_called_once_with(self.user)
mock_log.assert_called_once_with("Failed to sync user enrollments with edX")

@patch("courses.views.v1.sync_enrollments_with_edx")
@patch("courses.views.v1.is_enabled")
def test_list_with_sync_disabled(self, mock_is_enabled, mock_sync):
"""Test list method when sync is disabled"""
mock_is_enabled.return_value = False

response = self.client.get(reverse("v1:user-enrollments-api-list"))

assert response.status_code == status.HTTP_200_OK
mock_sync.assert_not_called()