Skip to content
This repository was archived by the owner on Jun 13, 2025. It is now read-only.

Commit ae546ed

Browse files
committed
fix: address feedback
1 parent 0efd853 commit ae546ed

File tree

3 files changed

+54
-32
lines changed

3 files changed

+54
-32
lines changed

graphql_api/types/flake_aggregates/flake_aggregates.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from graphql import GraphQLResolveInfo
66
from shared.django_apps.core.models import Repository
77

8+
from graphql_api.types.enums.enum_types import MeasurementInterval
89
from utils.test_results import get_results
910

1011

@@ -48,13 +49,17 @@ def flake_aggregates_with_percentage(
4849
return FlakeAggregates(**aggregates)
4950

5051

51-
def generate_flake_aggregates(repoid: int, interval: int) -> FlakeAggregates | None:
52+
def generate_flake_aggregates(
53+
repoid: int, interval: MeasurementInterval
54+
) -> FlakeAggregates | None:
5255
repo = Repository.objects.get(repoid=repoid)
5356

54-
curr_results = get_results(repo.repoid, repo.branch, interval)
57+
curr_results = get_results(repo.repoid, repo.branch, interval.value)
5558
if curr_results is None:
5659
return None
57-
past_results = get_results(repo.repoid, repo.branch, interval * 2, interval)
60+
past_results = get_results(
61+
repo.repoid, repo.branch, interval.value * 2, interval.value
62+
)
5863
if past_results is None:
5964
return flake_aggregates_from_table(curr_results)
6065
else:

graphql_api/types/test_analytics/test_analytics.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929

3030
log = logging.getLogger(__name__)
3131

32+
INTERVAL_30_DAY = 30
33+
INTERVAL_7_DAY = 7
34+
INTERVAL_1_DAY = 1
35+
3236

3337
@dataclass
3438
class TestResultsRow:
@@ -109,7 +113,7 @@ def validate(
109113
first: int | None,
110114
last: int | None,
111115
) -> None:
112-
if interval not in {1, 7, 30}:
116+
if interval not in {INTERVAL_1_DAY, INTERVAL_7_DAY, INTERVAL_30_DAY}:
113117
raise ValidationError(f"Invalid interval: {interval}")
114118

115119
if not isinstance(ordering_direction, OrderingDirection):
@@ -125,6 +129,22 @@ def validate(
125129
raise ValidationError("After and before can not be used at the same time")
126130

127131

132+
def ordering_expression(
133+
ordering: TestResultsOrderingParameter, cursor_value: CursorValue, is_forward: bool
134+
) -> pl.Expr:
135+
if is_forward:
136+
ordering_expression = (pl.col(ordering.value) > cursor_value.ordered_value) | (
137+
(pl.col(ordering.value) == cursor_value.ordered_value)
138+
& (pl.col("name") > cursor_value.name)
139+
)
140+
else:
141+
ordering_expression = (pl.col(ordering.value) < cursor_value.ordered_value) | (
142+
(pl.col(ordering.value) == cursor_value.ordered_value)
143+
& (pl.col("name") > cursor_value.name)
144+
)
145+
return ordering_expression
146+
147+
128148
def generate_test_results(
129149
ordering: TestResultsOrderingParameter,
130150
ordering_direction: OrderingDirection,
@@ -182,7 +202,8 @@ def generate_test_results(
182202

183203
if flags:
184204
table = table.filter(
185-
pl.col("flags").list.eval(pl.element().is_in(flags)).list.any()
205+
pl.col("flags").is_not_null()
206+
& pl.col("flags").list.eval(pl.element().is_in(flags)).list.any()
186207
)
187208

188209
match parameter:
@@ -201,23 +222,6 @@ def generate_test_results(
201222

202223
total_count = table.height
203224

204-
def ordering_expression(cursor_value: CursorValue, is_forward: bool) -> pl.Expr:
205-
if is_forward:
206-
ordering_expression = (
207-
pl.col(ordering.value) > cursor_value.ordered_value
208-
) | (
209-
(pl.col(ordering.value) == cursor_value.ordered_value)
210-
& (pl.col("name") > cursor_value.name)
211-
)
212-
else:
213-
ordering_expression = (
214-
pl.col(ordering.value) < cursor_value.ordered_value
215-
) | (
216-
(pl.col(ordering.value) == cursor_value.ordered_value)
217-
& (pl.col("name") > cursor_value.name)
218-
)
219-
return ordering_expression
220-
221225
if after:
222226
if ordering_direction == OrderingDirection.ASC:
223227
is_forward = True
@@ -226,7 +230,9 @@ def ordering_expression(cursor_value: CursorValue, is_forward: bool) -> pl.Expr:
226230

227231
cursor_value = decode_cursor(after, ordering)
228232
if cursor_value:
229-
table = table.filter(ordering_expression(cursor_value, is_forward))
233+
table = table.filter(
234+
ordering_expression(ordering, cursor_value, is_forward)
235+
)
230236
elif before:
231237
if ordering_direction == OrderingDirection.DESC:
232238
is_forward = True
@@ -235,7 +241,9 @@ def ordering_expression(cursor_value: CursorValue, is_forward: bool) -> pl.Expr:
235241

236242
cursor_value = decode_cursor(before, ordering)
237243
if cursor_value:
238-
table = table.filter(ordering_expression(cursor_value, is_forward))
244+
table = table.filter(
245+
ordering_expression(ordering, cursor_value, is_forward)
246+
)
239247

240248
table = table.sort(
241249
[ordering.value, "name"],
@@ -363,7 +371,8 @@ async def resolve_test_results_aggregates(
363371
**_: Any,
364372
) -> TestResultsAggregates | None:
365373
return await sync_to_async(generate_test_results_aggregates)(
366-
repoid=repository.repoid, interval=interval.value if interval else 30
374+
repoid=repository.repoid,
375+
interval=interval if interval else MeasurementInterval.INTERVAL_30_DAY,
367376
)
368377

369378

@@ -375,7 +384,8 @@ async def resolve_flake_aggregates(
375384
**_: Any,
376385
) -> FlakeAggregates | None:
377386
return await sync_to_async(generate_flake_aggregates)(
378-
repoid=repository.repoid, interval=interval.value if interval else 30
387+
repoid=repository.repoid,
388+
interval=interval if interval else MeasurementInterval.INTERVAL_30_DAY,
379389
)
380390

381391

graphql_api/types/test_results_aggregates/test_results_aggregates.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from shared.django_apps.core.models import Repository
77

88
from utils.test_results import get_results
9+
from graphql_api.types.enums.enum_types import MeasurementInterval
910

1011

1112
@dataclass
@@ -43,9 +44,11 @@ def calculate_aggregates(table: pl.DataFrame) -> pl.DataFrame:
4344
),
4445
(pl.col("total_skip_count").sum()).alias("skips"),
4546
(pl.col("total_fail_count").sum()).alias("fails"),
46-
((pl.col("avg_duration") >= pl.col("avg_duration").quantile(0.95)).sum()).alias(
47-
"total_slow_tests"
48-
),
47+
(
48+
(pl.col("avg_duration") >= pl.col("avg_duration").quantile(0.95))
49+
.top_k(100)
50+
.sum()
51+
).alias("total_slow_tests"),
4952
)
5053

5154

@@ -65,6 +68,8 @@ def test_results_aggregates_with_percentage(
6568

6669
merged_results: pl.DataFrame = pl.concat([past_aggregates, curr_aggregates])
6770

71+
# with_columns upserts the new columns, so if the name already exists it get overwritten
72+
# otherwise it's just added
6873
merged_results = merged_results.with_columns(
6974
pl.all().pct_change().name.suffix("_percent_change")
7075
)
@@ -74,14 +79,16 @@ def test_results_aggregates_with_percentage(
7479

7580

7681
def generate_test_results_aggregates(
77-
repoid: int, interval: int
82+
repoid: int, interval: MeasurementInterval
7883
) -> TestResultsAggregates | None:
7984
repo = Repository.objects.get(repoid=repoid)
8085

81-
curr_results = get_results(repo.repoid, repo.branch, interval)
86+
curr_results = get_results(repo.repoid, repo.branch, interval.value)
8287
if curr_results is None:
8388
return None
84-
past_results = get_results(repo.repoid, repo.branch, interval * 2, interval)
89+
past_results = get_results(
90+
repo.repoid, repo.branch, interval.value * 2, interval.value
91+
)
8592
if past_results is None:
8693
return test_results_aggregates_from_table(curr_results)
8794
else:

0 commit comments

Comments
 (0)