diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b5d361..ce9e3ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ ### 1.5.0 +- feat: #134: add `argMax` aggregation https://clickhouse.com/docs/sql-reference/aggregate-functions/reference/argmax - feat: #133: Fix simultaneous queries error when iteration is interrupted - feat: #130: Add `distributed_migrations` database setting to support distributed migration queries. - feat: #129: Add `toYearWeek` datetime functionality diff --git a/clickhouse_backend/models/aggregates.py b/clickhouse_backend/models/aggregates.py index b89d1d0..ed8dfea 100644 --- a/clickhouse_backend/models/aggregates.py +++ b/clickhouse_backend/models/aggregates.py @@ -1,4 +1,4 @@ -from django.db.models import aggregates +from django.db.models import aggregates, fields from django.db.models.expressions import Star from clickhouse_backend.models.fields import UInt64Field @@ -67,3 +67,24 @@ class uniqTheta(uniq): class anyLast(Aggregate): pass + + +class ArgMax(Aggregate): + function = "argMax" + name = "ArgMax" + arity = 2 + + def __init__(self, value_expr, order_by_expr, **extra): + if "output_field" not in extra: + # Infer output_field from value_expr + if hasattr(value_expr, "output_field"): + extra["output_field"] = value_expr.output_field + else: + # Fallback: assume CharField + extra["output_field"] = fields.CharField() + expressions = [value_expr, order_by_expr] + super().__init__(*expressions, **extra) + + def as_sql(self, compiler, connection, **extra_context): + self.extra["template"] = "%(function)s(%(expressions)s)" + return super().as_sql(compiler, connection, **extra_context) diff --git a/tests/aggregates/tests.py b/tests/aggregates/tests.py index f92be0f..2449c06 100644 --- a/tests/aggregates/tests.py +++ b/tests/aggregates/tests.py @@ -10,6 +10,7 @@ uniqHLL12, uniqTheta, ) +from clickhouse_backend.models.aggregates import ArgMax from .models import WatchSeries @@ -170,6 +171,20 @@ def setUpTestData(cls): # Use bulk_create to insert the list of objects in a single query WatchSeries.objects.bulk_create(watch_series_list) + def test_argMax(self): + result = ( + WatchSeries.objects.values("show") + .annotate(episode=ArgMax("episode", "date_id")) + .order_by("show") + ) + + expected_result = [ + {"show": "Bridgerton", "episode": "S1E1"}, + {"show": "Game of Thrones", "episode": "S1E1"}, + ] + + self.assertQuerysetEqual(result, expected_result, transform=dict) + def _test_uniq(self, cls_uniq): result = ( WatchSeries.objects.values("show", "episode")