Skip to content

Commit 4df6db3

Browse files
authored
Modify bulk update default value (#341)
* Modify bulk update default value * Patch bulk_update_with_default and add tests * Remove unused imports * Change case from default to None
1 parent 0060aa8 commit 4df6db3

File tree

4 files changed

+53
-4
lines changed

4 files changed

+53
-4
lines changed

mssql/functions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from django.core import validators
88
from django.db import NotSupportedError, connections, transaction
99
from django.db.models import BooleanField, CheckConstraint, Value
10-
from django.db.models.expressions import Case, Exists, Expression, OrderBy, When, Window
10+
from django.db.models.expressions import Case, Exists, OrderBy, When, Window
1111
from django.db.models.fields import BinaryField, Field
1212
from django.db.models.functions import Cast, NthValue, MD5, SHA1, SHA224, SHA256, SHA384, SHA512
1313
from django.db.models.functions.datetime import Now
@@ -294,7 +294,7 @@ def _get_check_sql(self, model, schema_editor):
294294
return sql % tuple(schema_editor.quote_value(p) for p in params)
295295

296296

297-
def bulk_update_with_default(self, objs, fields, batch_size=None, default=0):
297+
def bulk_update_with_default(self, objs, fields, batch_size=None, default=None):
298298
"""
299299
Update the given fields in each of the given objects in the database.
300300
@@ -343,7 +343,8 @@ def bulk_update_with_default(self, objs, fields, batch_size=None, default=0):
343343
attr = Value(attr, output_field=field)
344344
when_statements.append(When(pk=obj.pk, then=attr))
345345
if connection.vendor == 'microsoft' and value_none_counter == len(when_statements):
346-
case_statement = Case(*when_statements, output_field=field, default=Value(default))
346+
# We don't need a case statement if we are setting everything to None
347+
case_statement = Value(None)
347348
else:
348349
case_statement = Case(*when_statements, output_field=field)
349350
if requires_casting:
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Generated by Django 5.0.1 on 2024-01-29 14:18
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
8+
dependencies = [
9+
('testapp', '0024_publisher_book'),
10+
]
11+
12+
operations = [
13+
migrations.CreateModel(
14+
name='ModelWithNullableFieldsOfDifferentTypes',
15+
fields=[
16+
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
17+
('int_value', models.IntegerField(null=True)),
18+
('name', models.CharField(max_length=100, null=True)),
19+
('date', models.DateTimeField(null=True)),
20+
],
21+
),
22+
]

testapp/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ class UUIDModel(models.Model):
5454
def __str__(self):
5555
return self.pk
5656

57+
class ModelWithNullableFieldsOfDifferentTypes(models.Model):
58+
# Issue https://github.com/microsoft/mssql-django/issues/340
59+
# Ensures the integrity of bulk updates with different types
60+
int_value = models.IntegerField(null=True)
61+
name = models.CharField(max_length=100, null=True)
62+
date = models.DateTimeField(null=True)
5763

5864
class TestUniqueNullableModel(models.Model):
5965
# Issue https://github.com/ESSolutions/django-mssql-backend/issues/38:

testapp/tests/test_expressions.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the BSD license.
33

4+
import datetime
45
from unittest import skipUnless
56

67
from django import VERSION
@@ -9,7 +10,8 @@
910
from django.test import TestCase, skipUnlessDBFeature
1011

1112
from django.db.models.aggregates import Count
12-
from ..models import Author, Comment, Post, Editor
13+
from ..models import Author, Comment, Post, Editor, ModelWithNullableFieldsOfDifferentTypes
14+
1315

1416
DJANGO3 = VERSION[0] >= 3
1517

@@ -103,3 +105,21 @@ def test_order_by_nulls_first(self):
103105
self.assertEqual(len(results), 2)
104106
self.assertIsNone(results[0].alt_editor)
105107
self.assertIsNotNone(results[1].alt_editor)
108+
109+
class TestBulkUpdate(TestCase):
110+
def test_bulk_update_different_column_types(self):
111+
data = (
112+
(1, 'a', datetime.datetime(year=2024, month=1, day=1)),
113+
(2, 'b', datetime.datetime(year=2023, month=12, day=31))
114+
)
115+
objs = ModelWithNullableFieldsOfDifferentTypes.objects.bulk_create(ModelWithNullableFieldsOfDifferentTypes(int_value=row_data[0],
116+
name=row_data[1],
117+
date=row_data[2]) for row_data in data)
118+
for obj in objs:
119+
obj.int_value = None
120+
obj.name = None
121+
obj.date = None
122+
ModelWithNullableFieldsOfDifferentTypes.objects.bulk_update(objs, ["int_value", "name", "date"])
123+
self.assertCountEqual(ModelWithNullableFieldsOfDifferentTypes.objects.filter(int_value__isnull=True), objs)
124+
self.assertCountEqual(ModelWithNullableFieldsOfDifferentTypes.objects.filter(name__isnull=True), objs)
125+
self.assertCountEqual(ModelWithNullableFieldsOfDifferentTypes.objects.filter(date__isnull=True), objs)

0 commit comments

Comments
 (0)