diff --git a/src/django_mysql/models/aggregates.py b/src/django_mysql/models/aggregates.py index d4d18d03..dbca4704 100644 --- a/src/django_mysql/models/aggregates.py +++ b/src/django_mysql/models/aggregates.py @@ -1,5 +1,7 @@ from django.db.models import Aggregate, CharField +from django_mysql.models.fields import ListCharField, SetCharField + class BitAnd(Aggregate): function = "BIT_AND" @@ -24,8 +26,12 @@ def __init__( ): if "output_field" not in extra: - # This can/will be improved to SetTextField or ListTextField - extra["output_field"] = CharField() + if separator is not None: + extra["output_field"] = CharField() + elif distinct: + extra["output_field"] = SetCharField(CharField()) + else: + extra["output_field"] = ListCharField(CharField()) super().__init__(expression, **extra) diff --git a/tests/testapp/test_aggregates.py b/tests/testapp/test_aggregates.py index 654fcbda..4574af0c 100644 --- a/tests/testapp/test_aggregates.py +++ b/tests/testapp/test_aggregates.py @@ -74,8 +74,11 @@ def setUp(self): def test_basic_aggregate_ids(self): out = self.shakes.tutees.aggregate(tids=GroupConcat("id")) - concatted_ids = ",".join(self.str_tutee_ids) - assert out == {"tids": concatted_ids} + assert out == {"tids": self.str_tutee_ids} + + def test_distinct_aggregate_ids(self): + out = self.shakes.tutees.aggregate(tids=GroupConcat("id", distinct=True)) + assert out == {"tids": set(self.str_tutee_ids)} def test_basic_annotate_ids(self): concat = GroupConcat("tutees__id") @@ -104,14 +107,14 @@ def test_separator_big(self): def test_expression(self): concat = GroupConcat(F("id") + 1) out = self.shakes.tutees.aggregate(tids=concat) - concatted_ids = ",".join([str(self.jk.id + 1), str(self.grisham.id + 1)]) + concatted_ids = [str(self.jk.id + 1), str(self.grisham.id + 1)] assert out == {"tids": concatted_ids} def test_application_order(self): out = Author.objects.exclude(id=self.shakes.id).aggregate( tids=GroupConcat("tutor_id", distinct=True) ) - assert out == {"tids": str(self.shakes.id)} + assert out == {"tids": {str(self.shakes.id)}} @override_mysql_variables(SQL_MODE="ANSI") def test_separator_ansi_mode(self): @@ -127,11 +130,11 @@ def test_ordering_invalid(self): def test_ordering_asc(self): out = self.shakes.tutees.aggregate(tids=GroupConcat("id", ordering="asc")) - assert out == {"tids": ",".join(self.str_tutee_ids)} + assert out == {"tids": self.str_tutee_ids} def test_ordering_desc(self): out = self.shakes.tutees.aggregate(tids=GroupConcat("id", ordering="desc")) - assert out == {"tids": ",".join(reversed(self.str_tutee_ids))} + assert out == {"tids": list(reversed(self.str_tutee_ids))} def test_separator_ordering(self): concat = GroupConcat("id", separator=":", ordering="asc")