Skip to content

Commit fcb8e24

Browse files
committed
add argMax aggregation
1 parent 03b9adb commit fcb8e24

File tree

3 files changed

+38
-1
lines changed

3 files changed

+38
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
### 1.5.0
2+
- feat: add `argMax` aggregation https://clickhouse.com/docs/sql-reference/aggregate-functions/reference/argmax
23
- feat: #133: Fix simultaneous queries error when iteration is interrupted
34
- feat: #130: Add `distributed_migrations` database setting to support distributed migration queries.
45
- feat: #129: Add `toYearWeek` datetime functionality

clickhouse_backend/models/aggregates.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from django.db.models import aggregates
1+
from django.db.models import aggregates, fields
22
from django.db.models.expressions import Star
33

44
from clickhouse_backend.models.fields import UInt64Field
@@ -67,3 +67,24 @@ class uniqTheta(uniq):
6767

6868
class anyLast(Aggregate):
6969
pass
70+
71+
72+
class ArgMax(Aggregate):
73+
function = "argMax"
74+
name = "ArgMax"
75+
arity = 2
76+
77+
def __init__(self, value_expr, order_by_expr, **extra):
78+
if "output_field" not in extra:
79+
# Infer output_field from value_expr
80+
if hasattr(value_expr, "output_field"):
81+
extra["output_field"] = value_expr.output_field
82+
else:
83+
# Fallback: assume CharField
84+
extra["output_field"] = fields.CharField()
85+
expressions = [value_expr, order_by_expr]
86+
super().__init__(*expressions, **extra)
87+
88+
def as_sql(self, compiler, connection, **extra_context):
89+
self.extra["template"] = "%(function)s(%(expressions)s)"
90+
return super().as_sql(compiler, connection, **extra_context)

tests/aggregates/tests.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
uniqHLL12,
1111
uniqTheta,
1212
)
13+
from clickhouse_backend.models.aggregates import ArgMax
1314

1415
from .models import WatchSeries
1516

@@ -170,6 +171,20 @@ def setUpTestData(cls):
170171
# Use bulk_create to insert the list of objects in a single query
171172
WatchSeries.objects.bulk_create(watch_series_list)
172173

174+
def test_argMax(self):
175+
result = (
176+
WatchSeries.objects.values("show")
177+
.annotate(episode=ArgMax("episode", "date_id"))
178+
.order_by("show")
179+
)
180+
181+
expected_result = [
182+
{"show": "Bridgerton", "episode": "S1E1"},
183+
{"show": "Game of Thrones", "episode": "S1E1"},
184+
]
185+
186+
self.assertQuerysetEqual(result, expected_result, transform=dict)
187+
173188
def _test_uniq(self, cls_uniq):
174189
result = (
175190
WatchSeries.objects.values("show", "episode")

0 commit comments

Comments
 (0)