Skip to content

Commit 5c2fa77

Browse files
authored
Fix KeyTransformExact applied to all databases (#98)
* Fixed issue #82
1 parent f4ea0cd commit 5c2fa77

File tree

5 files changed

+145
-14
lines changed

5 files changed

+145
-14
lines changed

mssql/functions.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,16 @@
44
import json
55

66
from django import VERSION
7-
7+
from django.core import validators
88
from django.db import NotSupportedError, connections, transaction
9-
from django.db.models import BooleanField, Value
10-
from django.db.models.functions import Cast, NthValue
11-
from django.db.models.functions.math import ATan2, Log, Ln, Mod, Round
12-
from django.db.models.expressions import Case, Exists, OrderBy, When, Window, Expression
13-
from django.db.models.lookups import Lookup, In
14-
from django.db.models import lookups, CheckConstraint
9+
from django.db.models import BooleanField, CheckConstraint, Value
10+
from django.db.models.expressions import Case, Exists, Expression, OrderBy, When, Window
1511
from django.db.models.fields import BinaryField, Field
16-
from django.db.models.sql.query import Query
12+
from django.db.models.functions import Cast, NthValue
13+
from django.db.models.functions.math import ATan2, Ln, Log, Mod, Round
14+
from django.db.models.lookups import In, Lookup
1715
from django.db.models.query import QuerySet
18-
from django.core import validators
16+
from django.db.models.sql.query import Query
1917

2018
if VERSION >= (3, 1):
2119
from django.db.models.fields.json import (
@@ -67,9 +65,11 @@ def sqlserver_nth_value(self, compiler, connection, **extra_content):
6765
def sqlserver_round(self, compiler, connection, **extra_context):
6866
return self.as_sql(compiler, connection, template='%(function)s(%(expressions)s, 0)', **extra_context)
6967

68+
7069
def sqlserver_random(self, compiler, connection, **extra_context):
7170
return self.as_sql(compiler, connection, function='RAND', **extra_context)
7271

72+
7373
def sqlserver_window(self, compiler, connection, template=None):
7474
# MSSQL window functions require an OVER clause with ORDER BY
7575
if self.order_by is None:
@@ -143,26 +143,29 @@ def split_parameter_list_as_sql(self, compiler, connection):
143143

144144
return in_clause, ()
145145

146+
146147
def unquote_json_rhs(rhs_params):
147148
for value in rhs_params:
148149
value = json.loads(value)
149150
if not isinstance(value, (list, dict)):
150151
rhs_params = [param.replace('"', '') for param in rhs_params]
151152
return rhs_params
152153

154+
153155
def json_KeyTransformExact_process_rhs(self, compiler, connection):
154-
if isinstance(self.rhs, KeyTransform):
155-
return super(lookups.Exact, self).process_rhs(compiler, connection)
156-
rhs, rhs_params = super(KeyTransformExact, self).process_rhs(compiler, connection)
156+
rhs, rhs_params = key_transform_exact_process_rhs(self, compiler, connection)
157+
if connection.vendor == 'microsoft':
158+
rhs_params = unquote_json_rhs(rhs_params)
159+
return rhs, rhs_params
157160

158-
return rhs, unquote_json_rhs(rhs_params)
159161

160162
def json_KeyTransformIn(self, compiler, connection):
161163
lhs, _ = super(KeyTransformIn, self).process_lhs(compiler, connection)
162164
rhs, rhs_params = super(KeyTransformIn, self).process_rhs(compiler, connection)
163165

164166
return (lhs + ' IN ' + rhs, unquote_json_rhs(rhs_params))
165167

168+
166169
def json_HasKeyLookup(self, compiler, connection):
167170
# Process JSON path from the left-hand side.
168171
if isinstance(self.lhs, KeyTransform):
@@ -193,6 +196,7 @@ def json_HasKeyLookup(self, compiler, connection):
193196

194197
return sql % tuple(rhs_params), []
195198

199+
196200
def BinaryField_init(self, *args, **kwargs):
197201
# Add max_length option for BinaryField, default to max
198202
kwargs.setdefault('editable', False)
@@ -202,6 +206,7 @@ def BinaryField_init(self, *args, **kwargs):
202206
else:
203207
self.max_length = 'max'
204208

209+
205210
def _get_check_sql(self, model, schema_editor):
206211
if VERSION >= (3, 1):
207212
query = Query(model=model, alias_cols=False)
@@ -217,6 +222,7 @@ def _get_check_sql(self, model, schema_editor):
217222

218223
return sql % tuple(schema_editor.quote_value(p) for p in params)
219224

225+
220226
def bulk_update_with_default(self, objs, fields, batch_size=None, default=0):
221227
"""
222228
Update the given fields in each of the given objects in the database.
@@ -255,7 +261,7 @@ def bulk_update_with_default(self, objs, fields, batch_size=None, default=0):
255261
attr = getattr(obj, field.attname)
256262
if not isinstance(attr, Expression):
257263
if attr is None:
258-
value_none_counter+=1
264+
value_none_counter += 1
259265
attr = Value(attr, output_field=field)
260266
when_statements.append(When(pk=obj.pk, then=attr))
261267
if(value_none_counter == len(when_statements)):
@@ -272,10 +278,13 @@ def bulk_update_with_default(self, objs, fields, batch_size=None, default=0):
272278
rows_updated += self.filter(pk__in=pks).update(**update_kwargs)
273279
return rows_updated
274280

281+
275282
ATan2.as_microsoft = sqlserver_atan2
276283
In.split_parameter_list_as_sql = split_parameter_list_as_sql
277284
if VERSION >= (3, 1):
278285
KeyTransformIn.as_microsoft = json_KeyTransformIn
286+
# Need copy of old KeyTransformExact.process_rhs to call later
287+
key_transform_exact_process_rhs = KeyTransformExact.process_rhs
279288
KeyTransformExact.process_rhs = json_KeyTransformExact_process_rhs
280289
HasKeyLookup.as_microsoft = json_HasKeyLookup
281290
Ln.as_microsoft = sqlserver_ln
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Generated by Django 4.0.1 on 2022-02-01 15:58
2+
3+
from django import VERSION
4+
from django.db import migrations, models
5+
6+
7+
class Migration(migrations.Migration):
8+
9+
dependencies = [
10+
('testapp', '0015_test_rename_m2mfield_part2'),
11+
]
12+
13+
# JSONField added in Django 3.1
14+
if VERSION >= (3, 1):
15+
operations = [
16+
migrations.CreateModel(
17+
name='JSONModel',
18+
fields=[
19+
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
20+
('value', models.JSONField()),
21+
],
22+
options={
23+
'required_db_features': {'supports_json_field'},
24+
},
25+
),
26+
]
27+
else:
28+
pass

testapp/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import uuid
55

6+
from django import VERSION
67
from django.db import models
78
from django.db.models import Q
89
from django.utils import timezone
@@ -151,3 +152,11 @@ class Meta:
151152

152153
_type = models.CharField(max_length=50)
153154
status = models.CharField(max_length=50)
155+
156+
157+
if VERSION >= (3, 1):
158+
class JSONModel(models.Model):
159+
value = models.JSONField()
160+
161+
class Meta:
162+
required_db_features = {'supports_json_field'}

testapp/settings.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the BSD license.
33
import os
4+
from pathlib import Path
5+
6+
from django import VERSION
7+
8+
BASE_DIR = Path(__file__).resolve().parent.parent
49

510
DATABASES = {
611
"default": {
@@ -23,6 +28,14 @@
2328
},
2429
}
2530

31+
# Django 3.0 and below unit test doesn't handle more than 2 databases in DATABASES correctly
32+
if VERSION >= (3, 1):
33+
DATABASES['sqlite'] = {
34+
"ENGINE": "django.db.backends.sqlite3",
35+
"NAME": str(BASE_DIR / "db.sqlitetest"),
36+
}
37+
38+
2639
# Set to `True` locally if you want SQL queries logged to django_sql.log
2740
DEBUG = False
2841

@@ -267,6 +280,7 @@
267280
'backends.tests.BackendTestCase.test_queries_logger',
268281
'migrations.test_operations.OperationTests.test_alter_field_pk_mti_fk',
269282
'migrations.test_operations.OperationTests.test_run_sql_add_missing_semicolon_on_collect_sql',
283+
'migrations.test_operations.OperationTests.test_alter_field_pk_mti_and_fk_to_base'
270284
]
271285

272286
REGEX_TESTS = [

testapp/tests/test_jsonfield.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the BSD license.
3+
4+
from unittest import skipUnless
5+
6+
from django import VERSION
7+
from django.test import TestCase
8+
9+
if VERSION >= (3, 1):
10+
from ..models import JSONModel
11+
12+
13+
def _check_jsonfield_supported_sqlite():
14+
# Info about JSONField support in SQLite: https://code.djangoproject.com/wiki/JSON1Extension
15+
import sqlite3
16+
17+
supports_jsonfield = True
18+
try:
19+
conn = sqlite3.connect(':memory:')
20+
cursor = conn.cursor()
21+
cursor.execute('SELECT JSON(\'{"a": "b"}\')')
22+
except sqlite3.OperationalError:
23+
supports_jsonfield = False
24+
finally:
25+
return supports_jsonfield
26+
27+
28+
class TestJSONField(TestCase):
29+
databases = ['default']
30+
# Django 3.0 and below unit test doesn't handle more than 2 databases in DATABASES correctly
31+
if VERSION >= (3, 1):
32+
databases.append('sqlite')
33+
34+
json = {
35+
'a': 'b',
36+
'b': 1,
37+
'c': '1',
38+
'd': [],
39+
'e': [1, 2],
40+
'f': ['a', 'b'],
41+
'g': [1, 'a'],
42+
'h': {},
43+
'i': {'j': 1},
44+
'j': False,
45+
'k': True,
46+
'l': {
47+
'foo': 'bar',
48+
'baz': {'a': 'b', 'c': 'd'},
49+
'bar': ['foo', 'bar'],
50+
'bax': {'foo': 'bar'},
51+
},
52+
}
53+
54+
@skipUnless(VERSION >= (3, 1), "JSONField not support in Django versions < 3.1")
55+
@skipUnless(
56+
_check_jsonfield_supported_sqlite(),
57+
"JSONField not support by SQLite on this platform and Python version",
58+
)
59+
def test_keytransformexact_not_overriding(self):
60+
# Issue https://github.com/microsoft/mssql-django/issues/82
61+
json_obj = JSONModel(value=self.json)
62+
json_obj.save()
63+
self.assertSequenceEqual(
64+
JSONModel.objects.filter(value__a='b'),
65+
[json_obj],
66+
)
67+
json_obj.save(using='sqlite')
68+
self.assertSequenceEqual(
69+
JSONModel.objects.using('sqlite').filter(value__a='b'),
70+
[json_obj],
71+
)

0 commit comments

Comments
 (0)