Skip to content

Commit d01c01e

Browse files
authored
Fixed overridden functions not working with other DBs (#105)
Fixes issue #92 and other potential issues caused by overriding Django functions in functions.py.
1 parent 9450ca1 commit d01c01e

File tree

4 files changed

+171
-5
lines changed

4 files changed

+171
-5
lines changed

mssql/functions.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,13 @@ def sqlserver_orderby(self, compiler, connection):
125125

126126

127127
def split_parameter_list_as_sql(self, compiler, connection):
128+
if connection.vendor == 'microsoft':
129+
return mssql_split_parameter_list_as_sql(self, compiler, connection)
130+
else:
131+
return in_split_parameter_list_as_sql(self, compiler, connection)
132+
133+
134+
def mssql_split_parameter_list_as_sql(self, compiler, connection):
128135
# Insert In clause parameters 1000 at a time into a temp table.
129136
lhs, _ = self.process_lhs(compiler, connection)
130137
_, rhs_params = self.batch_process_rhs(compiler, connection)
@@ -215,10 +222,12 @@ def _get_check_sql(self, model, schema_editor):
215222
where = query.build_where(self.check)
216223
compiler = query.get_compiler(connection=schema_editor.connection)
217224
sql, params = where.as_sql(compiler, schema_editor.connection)
218-
try:
219-
for p in params: str(p).encode('ascii')
220-
except UnicodeEncodeError:
221-
sql = sql.replace('%s', 'N%s')
225+
if schema_editor.connection.vendor == 'microsoft':
226+
try:
227+
for p in params:
228+
str(p).encode('ascii')
229+
except UnicodeEncodeError:
230+
sql = sql.replace('%s', 'N%s')
222231

223232
return sql % tuple(schema_editor.quote_value(p) for p in params)
224233

@@ -264,7 +273,7 @@ def bulk_update_with_default(self, objs, fields, batch_size=None, default=0):
264273
value_none_counter += 1
265274
attr = Value(attr, output_field=field)
266275
when_statements.append(When(pk=obj.pk, then=attr))
267-
if(value_none_counter == len(when_statements)):
276+
if connections[self.db].vendor == 'microsoft' and value_none_counter == len(when_statements):
268277
case_statement = Case(*when_statements, output_field=field, default=Value(default))
269278
else:
270279
case_statement = Case(*when_statements, output_field=field)
@@ -280,6 +289,8 @@ def bulk_update_with_default(self, objs, fields, batch_size=None, default=0):
280289

281290

282291
ATan2.as_microsoft = sqlserver_atan2
292+
# Need copy of old In.split_parameter_list_as_sql for other backends to call
293+
in_split_parameter_list_as_sql = In.split_parameter_list_as_sql
283294
In.split_parameter_list_as_sql = split_parameter_list_as_sql
284295
if VERSION >= (3, 1):
285296
KeyTransformIn.as_microsoft = json_KeyTransformIn
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Generated by Django 4.0.2 on 2022-02-23 19:06
2+
3+
from django import VERSION
4+
from django.db import migrations, models
5+
6+
7+
class Migration(migrations.Migration):
8+
9+
dependencies = [
10+
('testapp', '0016_jsonmodel'),
11+
]
12+
13+
operations = [
14+
migrations.CreateModel(
15+
name='BinaryData',
16+
fields=[
17+
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
18+
('binary', models.BinaryField(max_length='max', null=True)),
19+
],
20+
),
21+
]
22+
23+
if VERSION >= (3, 2):
24+
operations += [
25+
migrations.CreateModel(
26+
name='TestCheckConstraintWithUnicode',
27+
fields=[
28+
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
29+
('name', models.CharField(max_length=100)),
30+
],
31+
options={
32+
'required_db_features': {'supports_table_check_constraints'},
33+
},
34+
),
35+
migrations.AddConstraint(
36+
model_name='testcheckconstraintwithunicode',
37+
constraint=models.CheckConstraint(check=models.Q(('name__startswith', '÷'), _negated=True), name='name_does_not_starts_with_÷'),
38+
),
39+
]

testapp/models.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,29 @@ class Meta:
154154
status = models.CharField(max_length=50)
155155

156156

157+
class BinaryData(models.Model):
158+
binary = models.BinaryField(null=True)
159+
160+
157161
if VERSION >= (3, 1):
158162
class JSONModel(models.Model):
159163
value = models.JSONField()
160164

161165
class Meta:
162166
required_db_features = {'supports_json_field'}
167+
168+
169+
if VERSION >= (3, 2):
170+
class TestCheckConstraintWithUnicode(models.Model):
171+
name = models.CharField(max_length=100)
172+
173+
class Meta:
174+
required_db_features = {
175+
'supports_table_check_constraints',
176+
}
177+
constraints = [
178+
models.CheckConstraint(
179+
check=~models.Q(name__startswith='\u00f7'),
180+
name='name_does_not_starts_with_\u00f7',
181+
)
182+
]
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the BSD license.
3+
4+
from unittest import skipUnless
5+
6+
from django import VERSION
7+
from django.db import OperationalError
8+
from django.db.backends.sqlite3.operations import DatabaseOperations
9+
from django.test import TestCase, skipUnlessDBFeature
10+
11+
from ..models import BinaryData, Pizza, Topping
12+
13+
if VERSION >= (3, 2):
14+
from ..models import TestCheckConstraintWithUnicode
15+
16+
17+
@skipUnless(
18+
VERSION >= (3, 1),
19+
"Django 3.0 and below doesn't support different databases in unit tests",
20+
)
21+
class TestMultpleDatabases(TestCase):
22+
databases = ['default', 'sqlite']
23+
24+
def test_in_split_parameter_list_as_sql(self):
25+
# Issue: https://github.com/microsoft/mssql-django/issues/92
26+
27+
# Mimic databases that have a limit on parameters (e.g. Oracle DB)
28+
old_max_in_list_size = DatabaseOperations.max_in_list_size
29+
DatabaseOperations.max_in_list_size = lambda self: 100
30+
31+
mssql_iterations = 3000
32+
Pizza.objects.bulk_create([Pizza() for _ in range(mssql_iterations)])
33+
Topping.objects.bulk_create([Topping() for _ in range(mssql_iterations)])
34+
prefetch_result = Pizza.objects.prefetch_related('toppings')
35+
self.assertEqual(len(prefetch_result), mssql_iterations)
36+
37+
# Different iterations since SQLite has max host parameters of 999 for versions prior to 3.32.0
38+
# Info about limit: https://www.sqlite.org/limits.html
39+
sqlite_iterations = 999
40+
Pizza.objects.using('sqlite').bulk_create([Pizza() for _ in range(sqlite_iterations)])
41+
Topping.objects.using('sqlite').bulk_create([Topping() for _ in range(sqlite_iterations)])
42+
prefetch_result_sqlite = Pizza.objects.using('sqlite').prefetch_related('toppings')
43+
self.assertEqual(len(prefetch_result_sqlite), sqlite_iterations)
44+
45+
DatabaseOperations.max_in_list_size = old_max_in_list_size
46+
47+
def test_binaryfield_init(self):
48+
binary_data = b'\x00\x46\xFE'
49+
binary = BinaryData(binary=binary_data)
50+
binary.save()
51+
binary.save(using='sqlite')
52+
53+
try:
54+
binary.full_clean()
55+
except ValidationError:
56+
self.fail()
57+
58+
b1 = BinaryData.objects.filter(binary=binary_data)
59+
self.assertSequenceEqual(
60+
b1,
61+
[binary],
62+
)
63+
b2 = BinaryData.objects.using('sqlite').filter(binary=binary_data)
64+
self.assertSequenceEqual(
65+
b2,
66+
[binary],
67+
)
68+
69+
@skipUnlessDBFeature('supports_table_check_constraints')
70+
@skipUnless(
71+
VERSION >= (3, 2),
72+
"Django 3.1 and below has errors from running migrations for this test",
73+
)
74+
def test_checkconstraint_get_check_sql(self):
75+
TestCheckConstraintWithUnicode.objects.create(name='abc')
76+
try:
77+
TestCheckConstraintWithUnicode.objects.using('sqlite').create(name='abc')
78+
except OperationalError:
79+
self.fail()
80+
81+
def test_queryset_bulk_update(self):
82+
objs = [
83+
BinaryData.objects.create(binary=b'\x00') for _ in range(5)
84+
]
85+
for obj in objs:
86+
obj.binary = None
87+
BinaryData.objects.bulk_update(objs, ["binary"])
88+
self.assertCountEqual(BinaryData.objects.filter(binary__isnull=True), objs)
89+
90+
objs = [
91+
BinaryData.objects.using('sqlite').create(binary=b'\x00') for _ in range(5)
92+
]
93+
for obj in objs:
94+
obj.binary = None
95+
BinaryData.objects.using('sqlite').bulk_update(objs, ["binary"])
96+
self.assertCountEqual(BinaryData.objects.using('sqlite').filter(binary__isnull=True), objs)

0 commit comments

Comments
 (0)