Skip to content
This repository was archived by the owner on Jun 13, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion api/public/v2/pull/serializers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from typing import Dict, Optional

from rest_framework import serializers

from api.public.v2.owner.serializers import OwnerSerializer
from api.shared.commit.serializers import CommitTotalsSerializer
from api.shared.commit.serializers import (
CommitTotalsSerializer,
PatchCoverageSerializer,
)
from core.models import Pull, PullStates
from services.comparison import CommitComparisonService, ComparisonReport


class PullSerializer(serializers.ModelSerializer):
Expand All @@ -18,6 +24,7 @@ class PullSerializer(serializers.ModelSerializer):
label="indicates whether the CI process passed for the head commit of this pull"
)
author = OwnerSerializer(label="pull author")
patch = serializers.SerializerMethodField()

class Meta:
model = Pull
Expand All @@ -30,5 +37,30 @@ class Meta:
"state",
"ci_passed",
"author",
"patch",
)
fields = read_only_fields

def get_patch(self, obj: Pull) -> Optional[Dict[str, float]]:
commit_comparison = CommitComparisonService.get_commit_comparison_for_pull(obj)
if not commit_comparison or not commit_comparison.is_processed:
return None
cr = ComparisonReport(commit_comparison)
hits = misses = partials = 0
for f in cr.impacted_files:
pc = f.patch_coverage
if pc:
hits += pc.hits
misses += pc.misses
partials += pc.partials
total_branches = hits + misses + partials
coverage = 0.0
if total_branches != 0:
coverage = round(100 * hits / total_branches, 2)
data = dict(
hits=hits,
misses=misses,
partials=partials,
coverage=coverage,
)
return PatchCoverageSerializer(data).data
146 changes: 130 additions & 16 deletions api/public/v2/tests/test_api_pull_viewset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import patch
from unittest.mock import MagicMock, patch

from django.test import override_settings
from django.urls import reverse
Expand All @@ -20,7 +20,8 @@ def setUp(self):
self.org = OwnerFactory()
self.repo = RepositoryFactory(author=self.org)
self.current_owner = OwnerFactory(
permission=[self.repo.repoid], organizations=[self.org.ownerid]
permission=[self.repo.repoid],
organizations=[self.org.ownerid],
)
self.pulls = [
PullFactory(repository=self.repo),
Expand All @@ -29,11 +30,13 @@ def setUp(self):
Pull.objects.filter(pk=self.pulls[1].pk).update(
updatestamp="2023-01-01T00:00:00"
)

self.client = APIClient()
self.client.force_login_owner(self.current_owner)
self.no_patch_response = dict(hits=0, misses=0, partials=0, coverage=0.0)

def test_list(self):
@patch("api.public.v2.pull.serializers.PullSerializer.get_patch")
def test_list(self, mock_patch):
mock_patch.return_value = self.no_patch_response
res = self.client.get(
reverse(
"api-v2-pulls-list",
Expand All @@ -59,6 +62,7 @@ def test_list(self):
"state": "open",
"ci_passed": None,
"author": None,
"patch": {"hits": 0, "misses": 0, "partials": 0, "coverage": 0.0},
},
{
"pullid": self.pulls[0].pullid,
Expand All @@ -69,12 +73,15 @@ def test_list(self):
"state": "open",
"ci_passed": None,
"author": None,
"patch": {"hits": 0, "misses": 0, "partials": 0, "coverage": 0.0},
},
],
"total_pages": 1,
}

def test_list_state(self):
@patch("api.public.v2.pull.serializers.PullSerializer.get_patch")
def test_list_state(self, mock_patch):
mock_patch.return_value = self.no_patch_response
pull = PullFactory(repository=self.repo, state="closed")
url = reverse(
"api-v2-pulls-list",
Expand All @@ -100,12 +107,15 @@ def test_list_state(self):
"state": "closed",
"ci_passed": None,
"author": None,
},
"patch": {"hits": 0, "misses": 0, "partials": 0, "coverage": 0.0},
}
],
"total_pages": 1,
}

def test_list_start_date(self):
@patch("api.public.v2.pull.serializers.PullSerializer.get_patch")
def test_list_start_date(self, mock_patch):
mock_patch.return_value = self.no_patch_response
url = reverse(
"api-v2-pulls-list",
kwargs={
Expand All @@ -130,12 +140,15 @@ def test_list_start_date(self):
"state": "open",
"ci_passed": None,
"author": None,
},
"patch": {"hits": 0, "misses": 0, "partials": 0, "coverage": 0.0},
}
],
"total_pages": 1,
}

def test_list_cursor_pagination(self):
@patch("api.public.v2.pull.serializers.PullSerializer.get_patch")
def test_list_cursor_pagination(self, mock_patch):
mock_patch.return_value = self.no_patch_response
url = reverse(
"api-v2-pulls-list",
kwargs={
Expand All @@ -157,7 +170,8 @@ def test_list_cursor_pagination(self):
"state": "open",
"ci_passed": None,
"author": None,
},
"patch": {"hits": 0, "misses": 0, "partials": 0, "coverage": 0.0},
}
]
assert data["previous"] is None
assert data["next"] is not None
Expand All @@ -174,15 +188,17 @@ def test_list_cursor_pagination(self):
"state": "open",
"ci_passed": None,
"author": None,
},
"patch": {"hits": 0, "misses": 0, "partials": 0, "coverage": 0.0},
}
]
assert data["previous"] is not None
assert data["next"] is None

@patch("api.public.v2.pull.serializers.PullSerializer.get_patch")
@patch("api.shared.repo.repository_accessors.RepoAccessors.get_repo_permissions")
def test_retrieve(self, get_repo_permissions):
def test_retrieve(self, get_repo_permissions, mock_patch):
mock_patch.return_value = self.no_patch_response
get_repo_permissions.return_value = (True, True)

res = self.client.get(
reverse(
"api-v2-pulls-detail",
Expand All @@ -204,6 +220,7 @@ def test_retrieve(self, get_repo_permissions):
"state": "open",
"ci_passed": None,
"author": None,
"patch": {"hits": 0, "misses": 0, "partials": 0, "coverage": 0.0},
}

@patch("api.shared.permissions.RepositoryArtifactPermissions.has_permission")
Expand All @@ -215,7 +232,6 @@ def test_no_pull_if_unauthenticated_token_request(
):
repository_artifact_permissions_has_permission.return_value = False
super_token_permissions_has_permission.return_value = False

res = self.client.get(
reverse(
"api-v2-pulls-detail",
Expand All @@ -238,7 +254,6 @@ def test_no_pull_if_not_super_token_nor_user_token(
self, repository_artifact_permissions_has_permission
):
repository_artifact_permissions_has_permission.return_value = False

res = self.client.get(
reverse(
"api-v2-pulls-detail",
Expand Down Expand Up @@ -278,7 +293,9 @@ def test_no_pull_if_super_token_but_no_GET_request(
)

@override_settings(SUPER_API_TOKEN="testaxs3o76rdcdpfzexuccx3uatui2nw73r")
def test_pull_with_valid_super_token(self):
@patch("api.public.v2.pull.serializers.PullSerializer.get_patch")
def test_pull_with_valid_super_token(self, mock_patch):
mock_patch.return_value = self.no_patch_response
res = self.client.get(
reverse(
"api-v2-pulls-detail",
Expand All @@ -301,4 +318,101 @@ def test_pull_with_valid_super_token(self):
"state": "open",
"ci_passed": None,
"author": None,
"patch": {"hits": 0, "misses": 0, "partials": 0, "coverage": 0.0},
}

@patch("api.public.v2.pull.serializers.ComparisonReport")
@patch("services.comparison.CommitComparison.objects.filter")
def test_retrieve_with_patch_coverage(self, mock_cc_filter, mock_comparison_report):
mock_cc_instance = MagicMock(is_processed=True)
mock_cc_filter.return_value.select_related.return_value.first.return_value = (
mock_cc_instance
)

mock_file = MagicMock()
mock_file.patch_coverage.hits = 10
mock_file.patch_coverage.misses = 5
mock_file.patch_coverage.partials = 2
mock_comparison_report.return_value.impacted_files = [mock_file]

res = self.client.get(
reverse(
"api-v2-pulls-detail",
kwargs={
"service": self.org.service,
"owner_username": self.org.username,
"repo_name": self.repo.name,
"pullid": self.pulls[0].pullid,
},
)
)
assert res.status_code == 200
data = res.json()
assert data["patch"] == {
"hits": 10,
"misses": 5,
"partials": 2,
"coverage": 58.82,
}

@patch("api.public.v2.pull.serializers.ComparisonReport")
@patch("services.comparison.CommitComparison.objects.filter")
def test_retrieve_with_patch_coverage_no_branches(
self, mock_cc_filter, mock_comparison_report
):
mock_cc_instance = MagicMock(is_processed=True)
mock_cc_filter.return_value.select_related.return_value.first.return_value = (
mock_cc_instance
)

mock_file = MagicMock()
mock_file.patch_coverage.hits = 0
mock_file.patch_coverage.misses = 0
mock_file.patch_coverage.partials = 0
mock_comparison_report.return_value.impacted_files = [mock_file]

res = self.client.get(
reverse(
"api-v2-pulls-detail",
kwargs={
"service": self.org.service,
"owner_username": self.org.username,
"repo_name": self.repo.name,
"pullid": self.pulls[0].pullid,
},
)
)
assert res.status_code == 200
data = res.json()
assert data["patch"] == self.no_patch_response

@patch("api.public.v2.pull.serializers.ComparisonReport")
@patch("services.comparison.CommitComparison.objects.filter")
def test_retrieve_with_patch_coverage_no_commit_comparison(
self, mock_cc_filter, mock_comparison_report
):
mock_cc_instance = MagicMock(is_processed=False)
mock_cc_filter.return_value.select_related.return_value.first.return_value = (
mock_cc_instance
)

mock_file = MagicMock()
mock_file.patch_coverage.hits = 0
mock_file.patch_coverage.misses = 0
mock_file.patch_coverage.partials = 0
mock_comparison_report.return_value.impacted_files = [mock_file]

res = self.client.get(
reverse(
"api-v2-pulls-detail",
kwargs={
"service": self.org.service,
"owner_username": self.org.username,
"repo_name": self.repo.name,
"pullid": self.pulls[0].pullid,
},
)
)
assert res.status_code == 200
data = res.json()
assert data["patch"] is None
7 changes: 7 additions & 0 deletions api/shared/commit/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ def get_coverage(self, totals) -> float:
return 0


class PatchCoverageSerializer(serializers.Serializer):
hits = serializers.IntegerField()
misses = serializers.IntegerField()
partials = serializers.IntegerField()
coverage = serializers.FloatField()


class CommitTotalsSerializer(BaseTotalsSerializer):
files = serializers.IntegerField(source="f")
lines = serializers.IntegerField(source="n")
Expand Down
15 changes: 13 additions & 2 deletions services/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from shared.utils.merge import LineType, line_type

from compare.models import CommitComparison
from core.models import Commit
from core.models import Commit, Pull
from reports.models import CommitReport
from services import ServiceException
from services.redis_configuration import get_redis_connection
Expand Down Expand Up @@ -1186,7 +1186,8 @@ def update_base_report_with_pseudo_diff(self):
class CommitComparisonService:
"""
Utilities for determining whether a commit comparison needs to be recomputed
(and enqueueing that recompute when necessary)
(and enqueueing that recompute when necessary), and fetching associated comparisons
for pulls
"""

def __init__(self, commit_comparison: CommitComparison):
Expand Down Expand Up @@ -1246,6 +1247,16 @@ def _load_commit(self, commit_id: int) -> Optional[Commit]:
.first()
)

@staticmethod
def get_commit_comparison_for_pull(obj: Pull) -> Optional[CommitComparison]:
comparison_qs = CommitComparison.objects.filter(
base_commit__commitid=obj.compared_to,
compare_commit__commitid=obj.head,
base_commit__repository_id=obj.repository_id,
compare_commit__repository_id=obj.repository_id,
).select_related("compare_commit", "base_commit")
return comparison_qs.first()

@classmethod
def fetch_precomputed(self, repo_id: int, keys: List[Tuple]) -> QuerySet:
comparison_table = CommitComparison._meta.db_table
Expand Down
Loading