Skip to content

Commit 6af4bce

Browse files
authored
Fix Exists in Case(When) (#29)
1 parent 968cfc8 commit 6af4bce

File tree

4 files changed

+54
-13
lines changed

4 files changed

+54
-13
lines changed

sql_server/pyodbc/compiler.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import django
55
from django.db.models.aggregates import Avg, Count, StdDev, Variance
6-
from django.db.models.expressions import Exists, OrderBy, Ref, Subquery, Value
6+
from django.db.models.expressions import OrderBy, Ref, Subquery, Value
77
from django.db.models.functions import (
88
Chr, ConcatPair, Greatest, Least, Length, LPad, Repeat, RPad, StrIndex, Substr, Trim
99
)
@@ -70,15 +70,6 @@ def _as_sql_lpad(self, compiler, connection):
7070
return template % {'expression': expression, 'length': length, 'fill_text': fill_text}, params
7171

7272

73-
def _as_sql_exists(self, compiler, connection, template=None, **extra_context):
74-
# MS SQL doesn't allow EXISTS() in the SELECT list, so wrap it with a
75-
# CASE WHEN expression. Change the template since the When expression
76-
# requires a left hand side (column) to compare against.
77-
sql, params = self.as_sql(compiler, connection, template, **extra_context)
78-
sql = 'CASE WHEN {} THEN 1 ELSE 0 END'.format(sql)
79-
return sql, params
80-
81-
8273
def _as_sql_order_by(self, compiler, connection):
8374
template = None
8475
if self.nulls_last:
@@ -399,8 +390,6 @@ def _as_microsoft(self, node):
399390
as_microsoft = _as_sql_rpad
400391
elif isinstance(node, LPad):
401392
as_microsoft = _as_sql_lpad
402-
elif isinstance(node, Exists):
403-
as_microsoft = _as_sql_exists
404393
elif isinstance(node, OrderBy):
405394
as_microsoft = _as_sql_order_by
406395
elif isinstance(node, Repeat):

sql_server/pyodbc/features.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
2121
ignores_quoted_identifier_case = True
2222
requires_literal_defaults = True
2323
requires_sqlparse_for_splitting = False
24+
supports_boolean_expr_in_select_clause = False
2425
supports_ignore_conflicts = False
2526
supports_index_on_text_field = False
2627
supports_paramstyle_pyformat = False

sql_server/pyodbc/functions.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1+
from django import VERSION
2+
from django.db.models import BooleanField
13
from django.db.models.functions import Cast
24
from django.db.models.functions.math import ATan2, Log, Ln, Round
5+
from django.db.models.expressions import Case, Exists, When
6+
from django.db.models.lookups import Lookup
7+
8+
DJANGO3 = VERSION[0] >= 3
39

410

511
class TryCast(Cast):
@@ -24,7 +30,35 @@ def sqlserver_round(self, compiler, connection, **extra_context):
2430
return self.as_sql(compiler, connection, template='%(function)s(%(expressions)s, 0)', **extra_context)
2531

2632

33+
def sqlserver_exists(self, compiler, connection, template=None, **extra_context):
34+
# MS SQL doesn't allow EXISTS() in the SELECT list, so wrap it with a
35+
# CASE WHEN expression. Change the template since the When expression
36+
# requires a left hand side (column) to compare against.
37+
sql, params = self.as_sql(compiler, connection, template, **extra_context)
38+
sql = 'CASE WHEN {} THEN 1 ELSE 0 END'.format(sql)
39+
return sql, params
40+
41+
42+
def sqlserver_lookup(self, compiler, connection):
43+
# MSSQL doesn't allow EXISTS() to be compared to another expression
44+
# unless it's wrapped in a CASE WHEN.
45+
wrapped = False
46+
exprs = []
47+
for expr in (self.lhs, self.rhs):
48+
if isinstance(expr, Exists):
49+
expr = Case(When(expr, then=True), default=False, output_field=BooleanField())
50+
wrapped = True
51+
exprs.append(expr)
52+
lookup = type(self)(*exprs) if wrapped else self
53+
return lookup.as_sql(compiler, connection)
54+
55+
2756
ATan2.as_microsoft = sqlserver_atan2
2857
Log.as_microsoft = sqlserver_log
2958
Ln.as_microsoft = sqlserver_ln
3059
Round.as_microsoft = sqlserver_round
60+
61+
if DJANGO3:
62+
Lookup.as_microsoft = sqlserver_lookup
63+
else:
64+
Exists.as_microsoft = sqlserver_exists

testapp/tests/test_expressions.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1-
from django.db.models.expressions import Exists, OuterRef, Subquery
1+
from unittest import skipUnless
2+
3+
from django import VERSION
4+
from django.db.models import IntegerField
5+
from django.db.models.expressions import Case, Exists, OuterRef, Subquery, Value, When
26
from django.db.utils import IntegrityError
37
from django.test import TestCase, skipUnlessDBFeature
48

59
from ..models import Author, Comment, Editor, Post
610

11+
DJANGO3 = VERSION[0] >= 3
12+
713

814
class TestSubquery(TestCase):
915
def setUp(self):
@@ -27,6 +33,17 @@ def test_with_count(self):
2733
post_exists=Exists(Post.objects.all())
2834
).filter(post_exists=True).count()
2935

36+
@skipUnless(DJANGO3, "Django 3 specific tests")
37+
def test_with_case_when(self):
38+
author = Author.objects.annotate(
39+
has_post=Case(
40+
When(Exists(Post.objects.filter(author=OuterRef('pk')).values('pk')), then=Value(1)),
41+
default=Value(0),
42+
output_field=IntegerField(),
43+
)
44+
).get()
45+
self.assertEqual(author.has_post, 1)
46+
3047

3148
@skipUnlessDBFeature('supports_partially_nullable_unique_constraints')
3249
class TestPartiallyNullableUniqueTogether(TestCase):

0 commit comments

Comments
 (0)