diff --git a/api/public/v1/serializers.py b/api/public/v1/serializers.py index ad95db7558..2f783694c4 100644 --- a/api/public/v1/serializers.py +++ b/api/public/v1/serializers.py @@ -16,3 +16,7 @@ class Meta: "state", ) fields = read_only_fields + ("user_provided_base_sha",) + + +class PullIdSerializer(serializers.Serializer): + pullid = serializers.IntegerField() diff --git a/api/public/v1/tests/views/test_pull_viewset.py b/api/public/v1/tests/views/test_pull_viewset.py index cdbda1fffb..f83e2b4386 100644 --- a/api/public/v1/tests/views/test_pull_viewset.py +++ b/api/public/v1/tests/views/test_pull_viewset.py @@ -133,3 +133,10 @@ def test_post_pull_user_provided_base(self, pulls_sync_mock): ) self.assertEqual(response.status_code, 405) assert not pulls_sync_mock.called + + def test_get_pull_no_pullid_provided(self): + self.client.credentials(HTTP_AUTHORIZATION="Token " + self.repo.upload_token) + response = self.client.get("/api/github/codecov/testRepoName/pulls/abc") + self.assertEqual(response.status_code, 400) + content = json.loads(response.content.decode()) + self.assertEqual(content["pullid"], ["A valid integer is required."]) diff --git a/api/public/v1/views.py b/api/public/v1/views.py index 37a45ec68c..b2e8546812 100644 --- a/api/public/v1/views.py +++ b/api/public/v1/views.py @@ -1,17 +1,17 @@ import logging -from django.db.models import OuterRef, Subquery +from django.db.models import OuterRef, QuerySet, Subquery from django.shortcuts import get_object_or_404 from django_filters.rest_framework import DjangoFilterBackend from rest_framework import filters, mixins, viewsets from api.shared.mixins import RepoPropertyMixin from codecov_auth.authentication.repo_auth import RepositoryLegacyTokenAuthentication -from core.models import Commit +from core.models import Commit, Pull from services.task import TaskService from .permissions import PullUpdatePermission -from .serializers import PullSerializer +from .serializers import PullIdSerializer, PullSerializer log = logging.getLogger(__name__) @@ -30,8 +30,11 @@ class PullViewSet( authentication_classes = [RepositoryLegacyTokenAuthentication] permission_classes = [PullUpdatePermission] - def get_object(self): - pullid = self.kwargs.get("pk") + def get_object(self) -> Pull: + serializer = PullIdSerializer(data={"pullid": self.kwargs.get("pk")}) + serializer.is_valid(raise_exception=True) + pullid = serializer.validated_data["pullid"] + if self.request.method == "PUT": # Note: We create a new pull if needed to make sure that they can be updated # with a base before the upload has finished processing. @@ -41,7 +44,7 @@ def get_object(self): return obj return get_object_or_404(self.get_queryset(), pullid=pullid) - def get_queryset(self): + def get_queryset(self) -> QuerySet: return self.repo.pull_requests.annotate( ci_passed=Subquery( Commit.objects.filter( @@ -50,7 +53,7 @@ def get_queryset(self): ) ) - def perform_update(self, serializer): + def perform_update(self, serializer: PullSerializer) -> Pull: result = super().perform_update(serializer) TaskService().pulls_sync(repoid=self.repo.repoid, pullid=self.kwargs.get("pk")) return result diff --git a/api/public/v2/compare/serializers.py b/api/public/v2/compare/serializers.py index 043fdf1ba3..e85843fe66 100644 --- a/api/public/v2/compare/serializers.py +++ b/api/public/v2/compare/serializers.py @@ -16,10 +16,11 @@ class ComparisonSerializer(BaseComparisonSerializer): def get_files(self, comparison: Comparison) -> List[dict]: data = [] - for filename in comparison.head_report.files: - file = comparison.get_file_comparison(filename, bypass_max_diff=True) - if self._should_include_file(file): - data.append(FileComparisonSerializer(file).data) + if comparison.head_report is not None: + for filename in comparison.head_report.files: + file = comparison.get_file_comparison(filename, bypass_max_diff=True) + if self._should_include_file(file): + data.append(FileComparisonSerializer(file).data) return data