Skip to content

Commit 93a3ac0

Browse files
POC: endpoint to get evals info (#246)
POC: endpoint to get evals info These changes add 2 new endpoints in the REST api (hidden behind a feature flag) with the objective of being a POC in reporting Eval info saved in Test Analytics. Probably not what anyone else had in mind, but a POC mind you. I didn't want to deal with building a frontend. Also the current way _we_ use evals is not in the CI, so I didn't think it made much sense to mess around the PR comment either. Anyway the endpoints are: * summary - aggregates eval data (that is saved per test-run). * compare - aggregate per commit and compare. * fix imports because tests exploded in CI * fix graphQL tests Apparently there was a custom field lookup in the Database router that was being used in the api. I moved to the shared one because the api one didn't know about the TA_TIMESERIES database, and the shared one didn't know about the special field. I think it's OK to move to just 1, and that should probably be the shared one? Hopefully. At least tests pass now :E * typing, docstrings and readability * renames and logic change Now all the items are part of the aggregation calculation. And we export "sum", not only "avg"
1 parent c047c24 commit 93a3ac0

File tree

10 files changed

+796
-71
lines changed

10 files changed

+796
-71
lines changed

.vscode/settings.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,8 @@
1212
"apps/worker",
1313
"libs/shared",
1414
],
15+
"cSpell.words": [
16+
"Testrun",
17+
"testruns"
18+
],
1519
}
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
from typing import TypedDict
2+
3+
import django_filters
4+
from django.http import JsonResponse
5+
from django_filters.rest_framework import DjangoFilterBackend
6+
from drf_spectacular.types import OpenApiTypes
7+
from drf_spectacular.utils import OpenApiParameter, extend_schema
8+
from rest_framework import viewsets
9+
from rest_framework.decorators import action
10+
11+
from api.public.v2.schema import repo_parameters
12+
from api.shared.mixins import RepoPropertyMixin
13+
from api.shared.permissions import RepositoryArtifactPermissions
14+
from rollouts import READ_NEW_EVALS
15+
from shared.django_apps.db_settings import TA_TIMESERIES_ENABLED
16+
from shared.django_apps.ta_timeseries.models import Testrun
17+
18+
19+
class EvalsSummary(TypedDict):
    """Aggregated eval metrics computed over a set of testruns.

    Keys are camelCase because this dict is serialized directly into the
    JSON response body of the evals endpoints.
    """

    avgDurationSeconds: float  # mean duration over all aggregated testruns
    avgCost: float  # mean eval cost over testruns that reported a cost
    totalItems: int  # number of testruns aggregated
    passedItems: int  # testruns whose outcome was "pass"
    failedItems: int  # totalItems - passedItems
    scores: dict[str, dict[str, float]]  # score name -> {"sum": ..., "avg": ...}
26+
27+
28+
class EvalsFilters(django_filters.FilterSet):
    """Query-parameter filters accepted by the evals endpoints.

    ``commit`` filters on the testrun's ``commit_sha`` column (exposed under
    the friendlier query-param name); ``classname`` filters directly on the
    ``classname`` column.
    """

    commit = django_filters.CharFilter(field_name="commit_sha")
    classname = django_filters.CharFilter(field_name="classname")

    class Meta:
        model = Testrun
        fields = ["commit", "classname"]
35+
36+
37+
class EvalsPermissions(RepositoryArtifactPermissions):
    """
    Permissions class for evals endpoints. Extends RepositoryArtifactPermissions
    to add a check for the READ_NEW_EVALS feature flag.
    """

    def has_permission(self, request, view):
        # Access requires, in order (short-circuiting like the original
        # sequential guards): basic repository access, the READ_NEW_EVALS
        # feature flag enabled for this repo, and a deployment where the
        # TA_TIMESERIES database is enabled.
        return bool(
            super().has_permission(request, view)
            and READ_NEW_EVALS.check_value(identifier=str(view.repo.repoid))
            and TA_TIMESERIES_ENABLED
        )
58+
59+
60+
@extend_schema(
    parameters=repo_parameters,
    tags=["Evaluations"],
)
class EvalsViewSet(viewsets.GenericViewSet, RepoPropertyMixin):
    """Read-only POC endpoints reporting eval data saved in Test Analytics.

    Exposes two list-level actions:

    * ``summary`` — aggregate eval metrics for the (filtered) testruns.
    * ``compare`` — aggregate per commit and report relative differences.
    """

    permission_classes = [EvalsPermissions]
    filter_backends = [DjangoFilterBackend]
    filterset_class = EvalsFilters

    def get_queryset(self):
        # Only testruns that actually carry eval data in `properties` are
        # relevant to these endpoints.
        return Testrun.objects.filter(
            repo_id=self.repo.repoid, properties__isnull=False
        )

    def _aggregate_testruns(self, testruns) -> EvalsSummary:
        """
        Aggregate metrics from a list of testruns.
        Returns a dict with aggregated metrics and scores.
        """
        # TODO: This function loads all testruns into memory.
        # If possible we should offload the calculation to postgres.
        # (although if it ever gets out of the POC I'd expect the rollup to exist in a separate table)

        total_items = len(testruns)
        passed_items = sum(1 for t in testruns if t.outcome == "pass")
        failed_items = total_items - passed_items

        avg_duration = (
            sum(t.duration_seconds or 0 for t in testruns) / total_items
            if total_items > 0
            else 0
        )

        # Calculate score sums and averages for all items with scores.
        # Maps score name -> (running sum, number of contributing testruns).
        score_agg_data: dict[str, tuple[float, int]] = {}
        cost_acc = 0
        items_with_cost = 0

        for testrun in testruns:
            eval_data = testrun.properties.get("eval", {})
            raw_scores = eval_data.get("scores", [])
            cost = eval_data.get("cost")
            # NOTE(review): a cost of exactly 0 is treated the same as "not
            # reported" and excluded from the average's denominator — confirm
            # that is intended.
            if cost:
                cost_acc += cost
                items_with_cost += 1

            # Consider scores from all items (not just passed ones)
            for score in raw_scores:
                name = score.get("name")
                if name:
                    # The numeric value may be reported under either "value"
                    # or "score". Check for None explicitly: the previous
                    # `get("value") or get("score")` dropped a legitimate
                    # value of 0 (falsy) and wrongly fell through to "score".
                    score_value = score.get("value")
                    if score_value is None:
                        score_value = score.get("score")
                    if isinstance(score_value, int | float):
                        running_sum, count = score_agg_data.get(name, (0, 0))
                        score_agg_data[name] = (
                            running_sum + score_value,
                            count + 1,
                        )

        # Create score aggregation dicts with both sum and avg
        scores = {
            name: {"sum": _sum, "avg": _sum / count if count > 0 else 0}
            for name, (_sum, count) in score_agg_data.items()
        }

        return {
            "avgDurationSeconds": avg_duration,
            "avgCost": cost_acc / items_with_cost if items_with_cost > 0 else 0,
            "totalItems": total_items,
            "passedItems": passed_items,
            "failedItems": failed_items,
            "scores": scores,
        }

    @extend_schema(
        summary="Evaluation summary",
        parameters=[
            OpenApiParameter(
                "commit",
                OpenApiTypes.STR,
                OpenApiParameter.QUERY,
                description="commit SHA for which to return evaluation summary",
            ),
            # "classname" is a terrible name but that's the name of the field in the testrun model
            # it is the name of the class that the test belongs to, or `describe` block in vitest
            # for langfuse it is the name of the run
            OpenApiParameter(
                "classname",
                OpenApiTypes.STR,
                OpenApiParameter.QUERY,
                description="class name the test belongs to, or `describe` block in vitest, or run name in langfuse",
            ),
        ],
    )
    @action(detail=False, methods=["get"])
    def summary(self, request, *args, **kwargs):
        """
        Returns a summary of evaluations for the specified repository and commit
        """
        # filter_queryset applies EvalsFilters (commit / classname).
        queryset = self.filter_queryset(self.get_queryset())
        testruns = list(queryset)
        return JsonResponse(self._aggregate_testruns(testruns))

    @extend_schema(
        summary="Evaluation compare",
        parameters=[
            OpenApiParameter(
                "base_sha",
                OpenApiTypes.STR,
                OpenApiParameter.QUERY,
                description="base commit SHA to compare from",
            ),
            OpenApiParameter(
                "head_sha",
                OpenApiTypes.STR,
                OpenApiParameter.QUERY,
                description="head commit SHA to compare to",
            ),
        ],
    )
    @action(detail=False, methods=["get"])
    def compare(self, request, *args, **kwargs):
        """
        Returns a comparison of evaluations between two commits
        """
        base_sha = request.query_params.get("base_sha")
        head_sha = request.query_params.get("head_sha")

        if not base_sha or not head_sha:
            return JsonResponse(
                {"error": "Both base_sha and head_sha are required"}, status=400
            )

        # Get testruns for both commits
        base_testruns = list(self.get_queryset().filter(commit_sha=base_sha))
        head_testruns = list(self.get_queryset().filter(commit_sha=head_sha))

        base_data = self._aggregate_testruns(base_testruns)
        head_data = self._aggregate_testruns(head_testruns)

        # Percentage change from base to head; a move from 0 to any non-zero
        # value is reported as 100 (can't divide by a zero base).
        def calculate_diff(base, head):
            if base == 0:
                return 0 if head == 0 else 100
            return ((head - base) / base) * 100

        # Compare scores across the union of score names seen on either side;
        # a side missing a score contributes zeros.
        score_diffs = {}
        all_score_names = set(base_data["scores"].keys()) | set(
            head_data["scores"].keys()
        )
        for score_name in all_score_names:
            base_score_data = base_data["scores"].get(score_name, {"sum": 0, "avg": 0})
            head_score_data = head_data["scores"].get(score_name, {"sum": 0, "avg": 0})

            score_diffs[score_name] = {
                "sum": calculate_diff(base_score_data["sum"], head_score_data["sum"]),
                "avg": calculate_diff(base_score_data["avg"], head_score_data["avg"]),
            }

        return JsonResponse(
            {
                "base": base_data,
                "head": head_data,
                "diff": {
                    "avgDurationSeconds": calculate_diff(
                        base_data["avgDurationSeconds"], head_data["avgDurationSeconds"]
                    ),
                    "avgCost": calculate_diff(
                        base_data["avgCost"], head_data["avgCost"]
                    ),
                    "totalItems": head_data["totalItems"] - base_data["totalItems"],
                    "passedItems": head_data["passedItems"] - base_data["passedItems"],
                    "failedItems": head_data["failedItems"] - base_data["failedItems"],
                    "scores": score_diffs,
                },
            }
        )

0 commit comments

Comments
 (0)