Skip to content

Commit adc01a5

Browse files
authored
Fix aggregate queries with case expressions (#354)
* Fix aggregate queries with case expressions
1 parent 0899188 commit adc01a5

File tree

3 files changed

+67
-5
lines changed

3 files changed

+67
-5
lines changed

mssql/base.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import time
1010
import struct
1111
import datetime
12+
from decimal import Decimal
13+
from uuid import UUID
1214

1315
from django.core.exceptions import ImproperlyConfigured
1416
from django.utils.functional import cached_property
@@ -571,6 +573,36 @@ def __init__(self, cursor, connection):
571573
self.last_sql = ''
572574
self.last_params = ()
573575

576+
def _as_sql_type(self, typ, value):
577+
if isinstance(value, str):
578+
length = len(value)
579+
if length == 0:
580+
return 'NVARCHAR'
581+
elif length > 4000:
582+
return 'NVARCHAR(max)'
583+
return 'NVARCHAR(%s)' % len(value)
584+
elif typ == int:
585+
if value < 0x7FFFFFFF and value > -0x7FFFFFFF:
586+
return 'INT'
587+
else:
588+
return 'BIGINT'
589+
elif typ == float:
590+
return 'DOUBLE PRECISION'
591+
elif typ == bool:
592+
return 'BIT'
593+
elif isinstance(value, Decimal):
594+
return 'NUMERIC'
595+
elif isinstance(value, datetime.datetime):
596+
return 'DATETIME2'
597+
elif isinstance(value, datetime.date):
598+
return 'DATE'
599+
elif isinstance(value, datetime.time):
600+
return 'TIME'
601+
elif isinstance(value, UUID):
602+
return 'uniqueidentifier'
603+
else:
604+
raise NotImplementedError('Not supported type %s (%s)' % (type(value), repr(value)))
605+
574606
def close(self):
575607
if self.active:
576608
self.active = False
@@ -588,6 +620,27 @@ def format_sql(self, sql, params):
588620

589621
return sql
590622

623+
def format_group_by_params(self, query, params):
624+
if params:
625+
# Insert None params directly into the query
626+
if None in params:
627+
null_params = ['NULL' if param is None else '%s' for param in params]
628+
query = query % tuple(null_params)
629+
params = tuple(p for p in params if p is not None)
630+
params = [(param, type(param)) for param in params]
631+
params_dict = {param: '@var%d' % i for i, param in enumerate(set(params))}
632+
args = [params_dict[param] for param in params]
633+
634+
variables = []
635+
params = []
636+
for key, value in params_dict.items():
637+
datatype = self._as_sql_type(key[1], key[0])
638+
variables.append("%s %s = %%s " % (value, datatype))
639+
params.append(key[0])
640+
query = ('DECLARE %s \n' % ','.join(variables)) + (query % tuple(args))
641+
642+
return query, params
643+
591644
def format_params(self, params):
592645
fp = []
593646
if params is not None:
@@ -616,6 +669,8 @@ def format_params(self, params):
616669

617670
def execute(self, sql, params=None):
618671
self.last_sql = sql
672+
if 'GROUP BY' in sql:
673+
sql, params = self.format_group_by_params(sql, params)
619674
sql = self.format_sql(sql, params)
620675
params = self.format_params(params)
621676
self.last_params = params

testapp/settings.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,6 @@
284284
'aggregation.test_filter_argument.FilteredAggregateTests.test_filtered_aggregate_ref_subquery_annotation',
285285
'aggregation.tests.AggregateAnnotationPruningTests.test_referenced_group_by_annotation_kept',
286286
'aggregation.tests.AggregateAnnotationPruningTests.test_referenced_window_requires_wrapping',
287-
'aggregation.tests.AggregateAnnotationPruningTests.test_unused_aliased_aggregate_and_annotation_reverse_fk',
288-
'aggregation.tests.AggregateAnnotationPruningTests.test_unused_aliased_aggregate_and_annotation_reverse_fk_grouped',
289287
'aggregation.tests.AggregateTestCase.test_group_by_nested_expression_with_params',
290288
'expressions.tests.BasicExpressionsTests.test_aggregate_subquery_annotation',
291289
'queries.test_qs_combinators.QuerySetSetOperationTests.test_union_order_with_null_first_last',

testapp/tests/test_expressions.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
from unittest import skipUnless
66

77
from django import VERSION
8-
from django.db.models import IntegerField, F
8+
from django.db.models import CharField, IntegerField, F
99
from django.db.models.expressions import Case, Exists, OuterRef, Subquery, Value, When, ExpressionWrapper
1010
from django.test import TestCase, skipUnlessDBFeature
1111

12-
from django.db.models.aggregates import Count
13-
from ..models import Author, Comment, Post, Editor, ModelWithNullableFieldsOfDifferentTypes
12+
from django.db.models.aggregates import Count, Sum
13+
14+
from ..models import Author, Book, Comment, Post, Editor, ModelWithNullableFieldsOfDifferentTypes
1415

1516

1617
DJANGO3 = VERSION[0] >= 3
@@ -85,6 +86,14 @@ def test_order_by_exists(self):
8586
self.assertSequenceEqual(authors_by_posts, [author_without_posts, self.author])
8687

8788

89+
class TestGroupBy(TestCase):
90+
def test_group_by_case(self):
91+
annotated_queryset = Book.objects.annotate(age=Case(
92+
When(id__gt=1000, then=Value("new")),
93+
default=Value("old"),
94+
output_field=CharField())).values('age').annotate(sum=Sum('id'))
95+
self.assertEqual(list(annotated_queryset.all()), [])
96+
8897
@skipUnless(DJANGO3, "Django 3 specific tests")
8998
@skipUnlessDBFeature("order_by_nulls_first")
9099
class TestOrderBy(TestCase):

0 commit comments

Comments
 (0)