diff --git a/README.rst b/README.rst index 52d93001..a9422c0d 100644 --- a/README.rst +++ b/README.rst @@ -7,6 +7,12 @@ django-pyodbc-azure .. image:: http://img.shields.io/pypi/l/django-pyodbc-azure.svg?style=flat :target: http://opensource.org/licenses/BSD-3-Clause +.. image:: https://ci.appveyor.com/api/projects/status/i9hfnl2gfeiq82qb?svg=true + :target: https://ci.appveyor.com/project/denisenkom/django-pyodbc-azure + +.. image:: https://codecov.io/gh/denisenkom/django-pyodbc-azure/branch/master/graph/badge.svg +  :target: https://codecov.io/gh/denisenkom/django-pyodbc-azure + *django-pyodbc-azure* is a modern fork of `django-pyodbc `__, a `Django `__ Microsoft SQL Server external diff --git a/appveyor.yml b/appveyor.yml new file mode 100644 index 00000000..12fc7767 --- /dev/null +++ b/appveyor.yml @@ -0,0 +1,52 @@ +version: 1.0.{build} + +os: Windows Server 2012 R2 + +environment: + HOST: localhost + SQLUSER: sa + SQLPASSWORD: Password12! + DATABASE: test + matrix: + - PYTHON: "C:\\Python36" + DJANGOVER: 1.11.3 + SQLINSTANCE: SQL2016 + - PYTHON: "C:\\Python36" + DJANGOVER: 1.10.7 + SQLINSTANCE: SQL2016 + - PYTHON: "C:\\Python36" + DJANGOVER: 1.9.13 + SQLINSTANCE: SQL2016 + #- PYTHON: "C:\\Python36" + # DJANGOVER: 1.11.3 + # SQLINSTANCE: SQL2014 + - PYTHON: "C:\\Python36" + DJANGOVER: 1.11.3 + SQLINSTANCE: SQL2012SP1 + +install: + - "SET PATH=%PYTHON%;%PYTHON%\\Scripts;%PATH%" + - python --version + - "python -c \"import struct; print(struct.calcsize('P') * 8)\"" + - pip install django==%DJANGOVER% + - pip install enum34 + - pip install python-memcached <= 1.53 + - pip install mock codecov + - pip install -e . 
+ +build_script: + - python setup.py sdist + +before_test: + # setup SQL Server + - ps: | + $instanceName = $env:SQLINSTANCE + Start-Service "MSSQL`$$instanceName" + Start-Service "SQLBrowser" + - sqlcmd -S "(local)\%SQLINSTANCE%" -Q "Use [master]; CREATE DATABASE test;" + - sqlcmd -S "(local)\%SQLINSTANCE%" -h -1 -Q "set nocount on; Select @@version" + + +test_script: + - coverage run tests/runtests.py --noinput --settings=test_mssql --debug-sql + - codecov diff --git a/tests/aggregation/__init__.py b/tests/aggregation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/aggregation/models.py b/tests/aggregation/models.py new file mode 100644 index 00000000..fd441fe5 --- /dev/null +++ b/tests/aggregation/models.py @@ -0,0 +1,44 @@ +from django.db import models + + +class Author(models.Model): + name = models.CharField(max_length=100) + age = models.IntegerField() + friends = models.ManyToManyField('self', blank=True) + + def __str__(self): + return self.name + + +class Publisher(models.Model): + name = models.CharField(max_length=255) + num_awards = models.IntegerField() + duration = models.DurationField(blank=True, null=True) + + def __str__(self): + return self.name + + +class Book(models.Model): + isbn = models.CharField(max_length=9) + name = models.CharField(max_length=255) + pages = models.IntegerField() + rating = models.FloatField() + price = models.DecimalField(decimal_places=2, max_digits=6) + authors = models.ManyToManyField(Author) + contact = models.ForeignKey(Author, models.CASCADE, related_name='book_contact_set') + publisher = models.ForeignKey(Publisher, models.CASCADE) + pubdate = models.DateField() + + def __str__(self): + return self.name + + +class Store(models.Model): + name = models.CharField(max_length=255) + books = models.ManyToManyField(Book) + original_opening = models.DateTimeField() + friday_night_closing = models.TimeField() + + def __str__(self): + return self.name diff --git 
a/tests/aggregation/test_filter_argument.py b/tests/aggregation/test_filter_argument.py new file mode 100644 index 00000000..54836178 --- /dev/null +++ b/tests/aggregation/test_filter_argument.py @@ -0,0 +1,81 @@ +import datetime +from decimal import Decimal + +from django.db.models import Case, Count, F, Q, Sum, When +from django.test import TestCase + +from .models import Author, Book, Publisher + + +class FilteredAggregateTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.a1 = Author.objects.create(name='test', age=40) + cls.a2 = Author.objects.create(name='test2', age=60) + cls.a3 = Author.objects.create(name='test3', age=100) + cls.p1 = Publisher.objects.create(name='Apress', num_awards=3, duration=datetime.timedelta(days=1)) + cls.b1 = Book.objects.create( + isbn='159059725', name='The Definitive Guide to Django: Web Development Done Right', + pages=447, rating=4.5, price=Decimal('30.00'), contact=cls.a1, publisher=cls.p1, + pubdate=datetime.date(2007, 12, 6), + ) + cls.b2 = Book.objects.create( + isbn='067232959', name='Sams Teach Yourself Django in 24 Hours', + pages=528, rating=3.0, price=Decimal('23.09'), contact=cls.a2, publisher=cls.p1, + pubdate=datetime.date(2008, 3, 3), + ) + cls.b3 = Book.objects.create( + isbn='159059996', name='Practical Django Projects', + pages=600, rating=4.5, price=Decimal('29.69'), contact=cls.a3, publisher=cls.p1, + pubdate=datetime.date(2008, 6, 23), + ) + cls.a1.friends.add(cls.a2) + cls.a1.friends.add(cls.a3) + cls.b1.authors.add(cls.a1) + cls.b1.authors.add(cls.a3) + cls.b2.authors.add(cls.a2) + cls.b3.authors.add(cls.a3) + + def test_filtered_aggregates(self): + agg = Sum('age', filter=Q(name__startswith='test')) + self.assertEqual(Author.objects.aggregate(age=agg)['age'], 200) + + def test_double_filtered_aggregates(self): + agg = Sum('age', filter=Q(Q(name='test2') & ~Q(name='test'))) + self.assertEqual(Author.objects.aggregate(age=agg)['age'], 60) + + def test_excluded_aggregates(self): + agg = 
Sum('age', filter=~Q(name='test2')) + self.assertEqual(Author.objects.aggregate(age=agg)['age'], 140) + + def test_related_aggregates_m2m(self): + agg = Sum('friends__age', filter=~Q(friends__name='test')) + self.assertEqual(Author.objects.filter(name='test').aggregate(age=agg)['age'], 160) + + def test_related_aggregates_m2m_and_fk(self): + q = Q(friends__book__publisher__name='Apress') & ~Q(friends__name='test3') + agg = Sum('friends__book__pages', filter=q) + self.assertEqual(Author.objects.filter(name='test').aggregate(pages=agg)['pages'], 528) + + def test_plain_annotate(self): + agg = Sum('book__pages', filter=Q(book__rating__gt=3)) + qs = Author.objects.annotate(pages=agg).order_by('pk') + self.assertSequenceEqual([a.pages for a in qs], [447, None, 1047]) + + def test_filtered_aggregate_on_annotate(self): + pages_annotate = Sum('book__pages', filter=Q(book__rating__gt=3)) + age_agg = Sum('age', filter=Q(total_pages__gte=400)) + aggregated = Author.objects.annotate(total_pages=pages_annotate).aggregate(summed_age=age_agg) + self.assertEqual(aggregated, {'summed_age': 140}) + + def test_case_aggregate(self): + agg = Sum( + Case(When(friends__age=40, then=F('friends__age'))), + filter=Q(friends__name__startswith='test'), + ) + self.assertEqual(Author.objects.aggregate(age=agg)['age'], 80) + + def test_sum_star_exception(self): + msg = 'Star cannot be used with filter. Please specify a field.' 
+ with self.assertRaisesMessage(ValueError, msg): + Count('*', filter=Q(age=40)) diff --git a/tests/aggregation/tests.py b/tests/aggregation/tests.py new file mode 100644 index 00000000..4572e2a8 --- /dev/null +++ b/tests/aggregation/tests.py @@ -0,0 +1,1109 @@ +import datetime +import re +from decimal import Decimal + +from django.core.exceptions import FieldError +from django.db import connection +from django.db.models import ( + Avg, Count, DecimalField, DurationField, F, FloatField, Func, IntegerField, + Max, Min, Sum, Value, +) +from django.test import TestCase +from django.test.utils import Approximate, CaptureQueriesContext +from django.utils import timezone + +from .models import Author, Book, Publisher, Store + + +class AggregateTestCase(TestCase): + + @classmethod + def setUpTestData(cls): + cls.a1 = Author.objects.create(name='Adrian Holovaty', age=34) + cls.a2 = Author.objects.create(name='Jacob Kaplan-Moss', age=35) + cls.a3 = Author.objects.create(name='Brad Dayley', age=45) + cls.a4 = Author.objects.create(name='James Bennett', age=29) + cls.a5 = Author.objects.create(name='Jeffrey Forcier', age=37) + cls.a6 = Author.objects.create(name='Paul Bissex', age=29) + cls.a7 = Author.objects.create(name='Wesley J. 
Chun', age=25) + cls.a8 = Author.objects.create(name='Peter Norvig', age=57) + cls.a9 = Author.objects.create(name='Stuart Russell', age=46) + cls.a1.friends.add(cls.a2, cls.a4) + cls.a2.friends.add(cls.a1, cls.a7) + cls.a4.friends.add(cls.a1) + cls.a5.friends.add(cls.a6, cls.a7) + cls.a6.friends.add(cls.a5, cls.a7) + cls.a7.friends.add(cls.a2, cls.a5, cls.a6) + cls.a8.friends.add(cls.a9) + cls.a9.friends.add(cls.a8) + + cls.p1 = Publisher.objects.create(name='Apress', num_awards=3, duration=datetime.timedelta(days=1)) + cls.p2 = Publisher.objects.create(name='Sams', num_awards=1, duration=datetime.timedelta(days=2)) + cls.p3 = Publisher.objects.create(name='Prentice Hall', num_awards=7) + cls.p4 = Publisher.objects.create(name='Morgan Kaufmann', num_awards=9) + cls.p5 = Publisher.objects.create(name="Jonno's House of Books", num_awards=0) + + cls.b1 = Book.objects.create( + isbn='159059725', name='The Definitive Guide to Django: Web Development Done Right', + pages=447, rating=4.5, price=Decimal('30.00'), contact=cls.a1, publisher=cls.p1, + pubdate=datetime.date(2007, 12, 6) + ) + cls.b2 = Book.objects.create( + isbn='067232959', name='Sams Teach Yourself Django in 24 Hours', + pages=528, rating=3.0, price=Decimal('23.09'), contact=cls.a3, publisher=cls.p2, + pubdate=datetime.date(2008, 3, 3) + ) + cls.b3 = Book.objects.create( + isbn='159059996', name='Practical Django Projects', + pages=300, rating=4.0, price=Decimal('29.69'), contact=cls.a4, publisher=cls.p1, + pubdate=datetime.date(2008, 6, 23) + ) + cls.b4 = Book.objects.create( + isbn='013235613', name='Python Web Development with Django', + pages=350, rating=4.0, price=Decimal('29.69'), contact=cls.a5, publisher=cls.p3, + pubdate=datetime.date(2008, 11, 3) + ) + cls.b5 = Book.objects.create( + isbn='013790395', name='Artificial Intelligence: A Modern Approach', + pages=1132, rating=4.0, price=Decimal('82.80'), contact=cls.a8, publisher=cls.p3, + pubdate=datetime.date(1995, 1, 15) + ) + cls.b6 = 
Book.objects.create( + isbn='155860191', name='Paradigms of Artificial Intelligence Programming: Case Studies in Common Lisp', + pages=946, rating=5.0, price=Decimal('75.00'), contact=cls.a8, publisher=cls.p4, + pubdate=datetime.date(1991, 10, 15) + ) + cls.b1.authors.add(cls.a1, cls.a2) + cls.b2.authors.add(cls.a3) + cls.b3.authors.add(cls.a4) + cls.b4.authors.add(cls.a5, cls.a6, cls.a7) + cls.b5.authors.add(cls.a8, cls.a9) + cls.b6.authors.add(cls.a8) + + s1 = Store.objects.create( + name='Amazon.com', + original_opening=datetime.datetime(1994, 4, 23, 9, 17, 42), + friday_night_closing=datetime.time(23, 59, 59) + ) + s2 = Store.objects.create( + name='Books.com', + original_opening=datetime.datetime(2001, 3, 15, 11, 23, 37), + friday_night_closing=datetime.time(23, 59, 59) + ) + s3 = Store.objects.create( + name="Mamma and Pappa's Books", + original_opening=datetime.datetime(1945, 4, 25, 16, 24, 14), + friday_night_closing=datetime.time(21, 30) + ) + s1.books.add(cls.b1, cls.b2, cls.b3, cls.b4, cls.b5, cls.b6) + s2.books.add(cls.b1, cls.b3, cls.b5, cls.b6) + s3.books.add(cls.b3, cls.b4, cls.b6) + + def test_empty_aggregate(self): + self.assertEqual(Author.objects.all().aggregate(), {}) + + def test_aggregate_in_order_by(self): + msg = ( + 'Using an aggregate in order_by() without also including it in ' + 'annotate() is not allowed: Avg(F(book__rating)' + ) + with self.assertRaisesMessage(FieldError, msg): + Author.objects.values('age').order_by(Avg('book__rating')) + + def test_single_aggregate(self): + vals = Author.objects.aggregate(Avg("age")) + self.assertEqual(vals, {"age__avg": Approximate(37.4, places=1)}) + + def test_multiple_aggregates(self): + vals = Author.objects.aggregate(Sum("age"), Avg("age")) + self.assertEqual(vals, {"age__sum": 337, "age__avg": Approximate(37.4, places=1)}) + + def test_filter_aggregate(self): + vals = Author.objects.filter(age__gt=29).aggregate(Sum("age")) + self.assertEqual(vals, {'age__sum': 254}) + + def 
test_related_aggregate(self): + vals = Author.objects.aggregate(Avg("friends__age")) + self.assertEqual(vals, {'friends__age__avg': Approximate(34.07, places=2)}) + + vals = Book.objects.filter(rating__lt=4.5).aggregate(Avg("authors__age")) + self.assertEqual(vals, {'authors__age__avg': Approximate(38.2857, places=2)}) + + vals = Author.objects.all().filter(name__contains="a").aggregate(Avg("book__rating")) + self.assertEqual(vals, {'book__rating__avg': 4.0}) + + vals = Book.objects.aggregate(Sum("publisher__num_awards")) + self.assertEqual(vals, {'publisher__num_awards__sum': 30}) + + vals = Publisher.objects.aggregate(Sum("book__price")) + self.assertEqual(vals, {'book__price__sum': Decimal('270.27')}) + + def test_aggregate_multi_join(self): + vals = Store.objects.aggregate(Max("books__authors__age")) + self.assertEqual(vals, {'books__authors__age__max': 57}) + + vals = Author.objects.aggregate(Min("book__publisher__num_awards")) + self.assertEqual(vals, {'book__publisher__num_awards__min': 1}) + + def test_aggregate_alias(self): + vals = Store.objects.filter(name="Amazon.com").aggregate(amazon_mean=Avg("books__rating")) + self.assertEqual(vals, {'amazon_mean': Approximate(4.08, places=2)}) + + def test_annotate_basic(self): + self.assertQuerysetEqual( + Book.objects.annotate().order_by('pk'), [ + "The Definitive Guide to Django: Web Development Done Right", + "Sams Teach Yourself Django in 24 Hours", + "Practical Django Projects", + "Python Web Development with Django", + "Artificial Intelligence: A Modern Approach", + "Paradigms of Artificial Intelligence Programming: Case Studies in Common Lisp" + ], + lambda b: b.name + ) + + books = Book.objects.annotate(mean_age=Avg("authors__age")) + b = books.get(pk=self.b1.pk) + self.assertEqual( + b.name, + 'The Definitive Guide to Django: Web Development Done Right' + ) + self.assertEqual(b.mean_age, 34.5) + + def test_annotate_defer(self): + qs = Book.objects.annotate( + 
page_sum=Sum("pages")).defer('name').filter(pk=self.b1.pk) + + rows = [ + (self.b1.id, "159059725", 447, "The Definitive Guide to Django: Web Development Done Right") + ] + self.assertQuerysetEqual( + qs.order_by('pk'), rows, + lambda r: (r.id, r.isbn, r.page_sum, r.name) + ) + + def test_annotate_defer_select_related(self): + qs = Book.objects.select_related('contact').annotate( + page_sum=Sum("pages")).defer('name').filter(pk=self.b1.pk) + + rows = [ + (self.b1.id, "159059725", 447, "Adrian Holovaty", + "The Definitive Guide to Django: Web Development Done Right") + ] + self.assertQuerysetEqual( + qs.order_by('pk'), rows, + lambda r: (r.id, r.isbn, r.page_sum, r.contact.name, r.name) + ) + + def test_annotate_m2m(self): + books = Book.objects.filter(rating__lt=4.5).annotate(Avg("authors__age")).order_by("name") + self.assertQuerysetEqual( + books, [ + ('Artificial Intelligence: A Modern Approach', 51.5), + ('Practical Django Projects', 29.0), + ('Python Web Development with Django', Approximate(30.3, places=1)), + ('Sams Teach Yourself Django in 24 Hours', 45.0) + ], + lambda b: (b.name, b.authors__age__avg), + ) + + books = Book.objects.annotate(num_authors=Count("authors")).order_by("name") + self.assertQuerysetEqual( + books, [ + ('Artificial Intelligence: A Modern Approach', 2), + ('Paradigms of Artificial Intelligence Programming: Case Studies in Common Lisp', 1), + ('Practical Django Projects', 1), + ('Python Web Development with Django', 3), + ('Sams Teach Yourself Django in 24 Hours', 1), + ('The Definitive Guide to Django: Web Development Done Right', 2) + ], + lambda b: (b.name, b.num_authors) + ) + + def test_backwards_m2m_annotate(self): + authors = Author.objects.filter(name__contains="a").annotate(Avg("book__rating")).order_by("name") + self.assertQuerysetEqual( + authors, [ + ('Adrian Holovaty', 4.5), + ('Brad Dayley', 3.0), + ('Jacob Kaplan-Moss', 4.5), + ('James Bennett', 4.0), + ('Paul Bissex', 4.0), + ('Stuart Russell', 4.0) + ], + lambda a: 
(a.name, a.book__rating__avg) + ) + + authors = Author.objects.annotate(num_books=Count("book")).order_by("name") + self.assertQuerysetEqual( + authors, [ + ('Adrian Holovaty', 1), + ('Brad Dayley', 1), + ('Jacob Kaplan-Moss', 1), + ('James Bennett', 1), + ('Jeffrey Forcier', 1), + ('Paul Bissex', 1), + ('Peter Norvig', 2), + ('Stuart Russell', 1), + ('Wesley J. Chun', 1) + ], + lambda a: (a.name, a.num_books) + ) + + def test_reverse_fkey_annotate(self): + books = Book.objects.annotate(Sum("publisher__num_awards")).order_by("name") + self.assertQuerysetEqual( + books, [ + ('Artificial Intelligence: A Modern Approach', 7), + ('Paradigms of Artificial Intelligence Programming: Case Studies in Common Lisp', 9), + ('Practical Django Projects', 3), + ('Python Web Development with Django', 7), + ('Sams Teach Yourself Django in 24 Hours', 1), + ('The Definitive Guide to Django: Web Development Done Right', 3) + ], + lambda b: (b.name, b.publisher__num_awards__sum) + ) + + publishers = Publisher.objects.annotate(Sum("book__price")).order_by("name") + self.assertQuerysetEqual( + publishers, [ + ('Apress', Decimal("59.69")), + ("Jonno's House of Books", None), + ('Morgan Kaufmann', Decimal("75.00")), + ('Prentice Hall', Decimal("112.49")), + ('Sams', Decimal("23.09")) + ], + lambda p: (p.name, p.book__price__sum) + ) + + def test_annotate_values(self): + books = list(Book.objects.filter(pk=self.b1.pk).annotate(mean_age=Avg("authors__age")).values()) + self.assertEqual( + books, [ + { + "contact_id": self.a1.id, + "id": self.b1.id, + "isbn": "159059725", + "mean_age": 34.5, + "name": "The Definitive Guide to Django: Web Development Done Right", + "pages": 447, + "price": Approximate(Decimal("30")), + "pubdate": datetime.date(2007, 12, 6), + "publisher_id": self.p1.id, + "rating": 4.5, + } + ] + ) + + books = ( + Book.objects + .filter(pk=self.b1.pk) + .annotate(mean_age=Avg('authors__age')) + .values('pk', 'isbn', 'mean_age') + ) + self.assertEqual( + list(books), [ + { + 
"pk": self.b1.pk, + "isbn": "159059725", + "mean_age": 34.5, + } + ] + ) + + books = Book.objects.filter(pk=self.b1.pk).annotate(mean_age=Avg("authors__age")).values("name") + self.assertEqual( + list(books), + [{'name': 'The Definitive Guide to Django: Web Development Done Right'}], + ) + + books = Book.objects.filter(pk=self.b1.pk).values().annotate(mean_age=Avg('authors__age')) + self.assertEqual( + list(books), [ + { + "contact_id": self.a1.id, + "id": self.b1.id, + "isbn": "159059725", + "mean_age": 34.5, + "name": "The Definitive Guide to Django: Web Development Done Right", + "pages": 447, + "price": Approximate(Decimal("30")), + "pubdate": datetime.date(2007, 12, 6), + "publisher_id": self.p1.id, + "rating": 4.5, + } + ] + ) + + books = ( + Book.objects + .values("rating") + .annotate(n_authors=Count("authors__id"), mean_age=Avg("authors__age")) + .order_by("rating") + ) + self.assertEqual( + list(books), [ + { + "rating": 3.0, + "n_authors": 1, + "mean_age": 45.0, + }, + { + "rating": 4.0, + "n_authors": 6, + "mean_age": Approximate(37.16, places=1) + }, + { + "rating": 4.5, + "n_authors": 2, + "mean_age": 34.5, + }, + { + "rating": 5.0, + "n_authors": 1, + "mean_age": 57.0, + } + ] + ) + + authors = Author.objects.annotate(Avg("friends__age")).order_by("name") + self.assertQuerysetEqual( + authors, [ + ('Adrian Holovaty', 32.0), + ('Brad Dayley', None), + ('Jacob Kaplan-Moss', 29.5), + ('James Bennett', 34.0), + ('Jeffrey Forcier', 27.0), + ('Paul Bissex', 31.0), + ('Peter Norvig', 46.0), + ('Stuart Russell', 57.0), + ('Wesley J. 
Chun', Approximate(33.66, places=1)) + ], + lambda a: (a.name, a.friends__age__avg) + ) + + def test_count(self): + vals = Book.objects.aggregate(Count("rating")) + self.assertEqual(vals, {"rating__count": 6}) + + vals = Book.objects.aggregate(Count("rating", distinct=True)) + self.assertEqual(vals, {"rating__count": 4}) + + #def test_count_star(self): + # with self.assertNumQueries(1) as ctx: + # Book.objects.aggregate(n=Count("*")) + # sql = ctx.captured_queries[0]['sql'] + # self.assertIn('SELECT COUNT(*) ', sql) + + def test_non_grouped_annotation_not_in_group_by(self): + """ + An annotation not included in values() before an aggregate should be + excluded from the group by clause. + """ + qs = ( + Book.objects.annotate(xprice=F('price')).filter(rating=4.0).values('rating') + .annotate(count=Count('publisher_id', distinct=True)).values('count', 'rating').order_by('count') + ) + self.assertEqual(list(qs), [{'rating': 4.0, 'count': 2}]) + + def test_grouped_annotation_in_group_by(self): + """ + An annotation included in values() before an aggregate should be + included in the group by clause. 
+ """ + qs = ( + Book.objects.annotate(xprice=F('price')).filter(rating=4.0).values('rating', 'xprice') + .annotate(count=Count('publisher_id', distinct=True)).values('count', 'rating').order_by('count') + ) + self.assertEqual( + list(qs), [ + {'rating': 4.0, 'count': 1}, + {'rating': 4.0, 'count': 2}, + ] + ) + + def test_fkey_aggregate(self): + explicit = list(Author.objects.annotate(Count('book__id'))) + implicit = list(Author.objects.annotate(Count('book'))) + self.assertEqual(explicit, implicit) + + def test_annotate_ordering(self): + books = Book.objects.values('rating').annotate(oldest=Max('authors__age')).order_by('oldest', 'rating') + self.assertEqual( + list(books), [ + {'rating': 4.5, 'oldest': 35}, + {'rating': 3.0, 'oldest': 45}, + {'rating': 4.0, 'oldest': 57}, + {'rating': 5.0, 'oldest': 57}, + ] + ) + + books = Book.objects.values("rating").annotate(oldest=Max("authors__age")).order_by("-oldest", "-rating") + self.assertEqual( + list(books), [ + {'rating': 5.0, 'oldest': 57}, + {'rating': 4.0, 'oldest': 57}, + {'rating': 3.0, 'oldest': 45}, + {'rating': 4.5, 'oldest': 35}, + ] + ) + + def test_aggregate_annotation(self): + vals = Book.objects.annotate(num_authors=Count("authors__id")).aggregate(Avg("num_authors")) + self.assertEqual(vals, {"num_authors__avg": Approximate(1.66, places=1)}) + + def test_avg_duration_field(self): + # Explicit `output_field`. + self.assertEqual( + Publisher.objects.aggregate(Avg('duration', output_field=DurationField())), + {'duration__avg': datetime.timedelta(days=1, hours=12)} + ) + # Implicit `output_field`. 
+ self.assertEqual( + Publisher.objects.aggregate(Avg('duration')), + {'duration__avg': datetime.timedelta(days=1, hours=12)} + ) + + def test_sum_duration_field(self): + self.assertEqual( + Publisher.objects.aggregate(Sum('duration', output_field=DurationField())), + {'duration__sum': datetime.timedelta(days=3)} + ) + + def test_sum_distinct_aggregate(self): + """ + Sum on a distinct() QuerySet should aggregate only the distinct items. + """ + authors = Author.objects.filter(book__in=[self.b5, self.b6]) + self.assertEqual(authors.count(), 3) + + distinct_authors = authors.distinct() + self.assertEqual(distinct_authors.count(), 2) + + # Selected author ages are 57 and 46 + age_sum = distinct_authors.aggregate(Sum('age')) + self.assertEqual(age_sum['age__sum'], 103) + + def test_filtering(self): + p = Publisher.objects.create(name='Expensive Publisher', num_awards=0) + Book.objects.create( + name='ExpensiveBook1', + pages=1, + isbn='111', + rating=3.5, + price=Decimal("1000"), + publisher=p, + contact_id=self.a1.id, + pubdate=datetime.date(2008, 12, 1) + ) + Book.objects.create( + name='ExpensiveBook2', + pages=1, + isbn='222', + rating=4.0, + price=Decimal("1000"), + publisher=p, + contact_id=self.a1.id, + pubdate=datetime.date(2008, 12, 2) + ) + Book.objects.create( + name='ExpensiveBook3', + pages=1, + isbn='333', + rating=4.5, + price=Decimal("35"), + publisher=p, + contact_id=self.a1.id, + pubdate=datetime.date(2008, 12, 3) + ) + + publishers = Publisher.objects.annotate(num_books=Count("book__id")).filter(num_books__gt=1).order_by("pk") + self.assertQuerysetEqual( + publishers, + ['Apress', 'Prentice Hall', 'Expensive Publisher'], + lambda p: p.name, + ) + + publishers = Publisher.objects.filter(book__price__lt=Decimal("40.0")).order_by("pk") + self.assertQuerysetEqual( + publishers, [ + "Apress", + "Apress", + "Sams", + "Prentice Hall", + "Expensive Publisher", + ], + lambda p: p.name + ) + + publishers = ( + Publisher.objects + 
.annotate(num_books=Count("book__id")) + .filter(num_books__gt=1, book__price__lt=Decimal("40.0")) + .order_by("pk") + ) + self.assertQuerysetEqual( + publishers, + ['Apress', 'Prentice Hall', 'Expensive Publisher'], + lambda p: p.name, + ) + + publishers = ( + Publisher.objects + .filter(book__price__lt=Decimal("40.0")) + .annotate(num_books=Count("book__id")) + .filter(num_books__gt=1) + .order_by("pk") + ) + self.assertQuerysetEqual(publishers, ['Apress'], lambda p: p.name) + + publishers = Publisher.objects.annotate(num_books=Count("book")).filter(num_books__range=[1, 3]).order_by("pk") + self.assertQuerysetEqual( + publishers, [ + "Apress", + "Sams", + "Prentice Hall", + "Morgan Kaufmann", + "Expensive Publisher", + ], + lambda p: p.name + ) + + publishers = Publisher.objects.annotate(num_books=Count("book")).filter(num_books__range=[1, 2]).order_by("pk") + self.assertQuerysetEqual( + publishers, + ['Apress', 'Sams', 'Prentice Hall', 'Morgan Kaufmann'], + lambda p: p.name + ) + + publishers = Publisher.objects.annotate(num_books=Count("book")).filter(num_books__in=[1, 3]).order_by("pk") + self.assertQuerysetEqual( + publishers, + ['Sams', 'Morgan Kaufmann', 'Expensive Publisher'], + lambda p: p.name, + ) + + publishers = Publisher.objects.annotate(num_books=Count("book")).filter(num_books__isnull=True) + self.assertEqual(len(publishers), 0) + + def test_annotation(self): + vals = Author.objects.filter(pk=self.a1.pk).aggregate(Count("friends__id")) + self.assertEqual(vals, {"friends__id__count": 2}) + + books = Book.objects.annotate(num_authors=Count("authors__name")).filter(num_authors__exact=2).order_by("pk") + self.assertQuerysetEqual( + books, [ + "The Definitive Guide to Django: Web Development Done Right", + "Artificial Intelligence: A Modern Approach", + ], + lambda b: b.name + ) + + authors = ( + Author.objects + .annotate(num_friends=Count("friends__id", distinct=True)) + .filter(num_friends=0) + .order_by("pk") + ) + self.assertQuerysetEqual(authors, 
['Brad Dayley'], lambda a: a.name) + + publishers = Publisher.objects.annotate(num_books=Count("book__id")).filter(num_books__gt=1).order_by("pk") + self.assertQuerysetEqual(publishers, ['Apress', 'Prentice Hall'], lambda p: p.name) + + publishers = ( + Publisher.objects + .filter(book__price__lt=Decimal("40.0")) + .annotate(num_books=Count("book__id")) + .filter(num_books__gt=1) + ) + self.assertQuerysetEqual(publishers, ['Apress'], lambda p: p.name) + + books = ( + Book.objects + .annotate(num_authors=Count("authors__id")) + .filter(authors__name__contains="Norvig", num_authors__gt=1) + ) + self.assertQuerysetEqual( + books, + ['Artificial Intelligence: A Modern Approach'], + lambda b: b.name + ) + + def test_more_aggregation(self): + a = Author.objects.get(name__contains='Norvig') + b = Book.objects.get(name__contains='Done Right') + b.authors.add(a) + b.save() + + vals = ( + Book.objects + .annotate(num_authors=Count("authors__id")) + .filter(authors__name__contains="Norvig", num_authors__gt=1) + .aggregate(Avg("rating")) + ) + self.assertEqual(vals, {"rating__avg": 4.25}) + + def test_even_more_aggregate(self): + publishers = Publisher.objects.annotate( + earliest_book=Min("book__pubdate"), + ).exclude(earliest_book=None).order_by("earliest_book").values( + 'earliest_book', + 'num_awards', + 'id', + 'name', + ) + self.assertEqual( + list(publishers), [ + { + 'earliest_book': datetime.date(1991, 10, 15), + 'num_awards': 9, + 'id': self.p4.id, + 'name': 'Morgan Kaufmann' + }, + { + 'earliest_book': datetime.date(1995, 1, 15), + 'num_awards': 7, + 'id': self.p3.id, + 'name': 'Prentice Hall' + }, + { + 'earliest_book': datetime.date(2007, 12, 6), + 'num_awards': 3, + 'id': self.p1.id, + 'name': 'Apress' + }, + { + 'earliest_book': datetime.date(2008, 3, 3), + 'num_awards': 1, + 'id': self.p2.id, + 'name': 'Sams' + } + ] + ) + + vals = Store.objects.aggregate(Max("friday_night_closing"), Min("original_opening")) + self.assertEqual( + vals, + { + 
"friday_night_closing__max": datetime.time(23, 59, 59), + "original_opening__min": datetime.datetime(1945, 4, 25, 16, 24, 14), + } + ) + + def test_annotate_values_list(self): + books = ( + Book.objects + .filter(pk=self.b1.pk) + .annotate(mean_age=Avg("authors__age")) + .values_list("pk", "isbn", "mean_age") + ) + self.assertEqual(list(books), [(self.b1.id, '159059725', 34.5)]) + + books = Book.objects.filter(pk=self.b1.pk).annotate(mean_age=Avg("authors__age")).values_list("isbn") + self.assertEqual(list(books), [('159059725',)]) + + books = Book.objects.filter(pk=self.b1.pk).annotate(mean_age=Avg("authors__age")).values_list("mean_age") + self.assertEqual(list(books), [(34.5,)]) + + books = ( + Book.objects + .filter(pk=self.b1.pk) + .annotate(mean_age=Avg("authors__age")) + .values_list("mean_age", flat=True) + ) + self.assertEqual(list(books), [34.5]) + + books = Book.objects.values_list("price").annotate(count=Count("price")).order_by("-count", "price") + self.assertEqual( + list(books), [ + (Decimal("29.69"), 2), + (Decimal('23.09'), 1), + (Decimal('30'), 1), + (Decimal('75'), 1), + (Decimal('82.8'), 1), + ] + ) + + def test_dates_with_aggregation(self): + """ + .dates() returns a distinct set of dates when applied to a + QuerySet with aggregation. + + Refs #18056. Previously, .dates() would return distinct (date_kind, + aggregation) sets, in this case (year, num_authors), so 2008 would be + returned twice because there are books from 2008 with a different + number of authors. 
+ """ + dates = Book.objects.annotate(num_authors=Count("authors")).dates('pubdate', 'year') + self.assertQuerysetEqual( + dates, [ + "datetime.date(1991, 1, 1)", + "datetime.date(1995, 1, 1)", + "datetime.date(2007, 1, 1)", + "datetime.date(2008, 1, 1)" + ] + ) + + def test_values_aggregation(self): + # Refs #20782 + max_rating = Book.objects.values('rating').aggregate(max_rating=Max('rating')) + self.assertEqual(max_rating['max_rating'], 5) + max_books_per_rating = Book.objects.values('rating').annotate( + books_per_rating=Count('id') + ).aggregate(Max('books_per_rating')) + self.assertEqual( + max_books_per_rating, + {'books_per_rating__max': 3}) + + def test_ticket17424(self): + """ + Doing exclude() on a foreign model after annotate() doesn't crash. + """ + all_books = list(Book.objects.values_list('pk', flat=True).order_by('pk')) + annotated_books = Book.objects.order_by('pk').annotate(one=Count("id")) + + # The value doesn't matter, we just need any negative + # constraint on a related model that's a noop. + excluded_books = annotated_books.exclude(publisher__name="__UNLIKELY_VALUE__") + + # Try to generate query tree + str(excluded_books.query) + + self.assertQuerysetEqual(excluded_books, all_books, lambda x: x.pk) + + # Check internal state + self.assertIsNone(annotated_books.query.alias_map["aggregation_book"].join_type) + self.assertIsNone(excluded_books.query.alias_map["aggregation_book"].join_type) + + def test_ticket12886(self): + """ + Aggregation over sliced queryset works correctly. + """ + qs = Book.objects.all().order_by('-rating')[0:3] + vals = qs.aggregate(average_top3_rating=Avg('rating'))['average_top3_rating'] + self.assertAlmostEqual(vals, 4.5, places=2) + + def test_ticket11881(self): + """ + Subqueries do not needlessly contain ORDER BY, SELECT FOR UPDATE or + select_related() stuff. 
+ """ + qs = Book.objects.all().select_for_update().order_by( + 'pk').select_related('publisher').annotate(max_pk=Max('pk')) + with CaptureQueriesContext(connection) as captured_queries: + qs.aggregate(avg_pk=Avg('max_pk')) + self.assertEqual(len(captured_queries), 1) + qstr = captured_queries[0]['sql'].lower() + self.assertNotIn('for update', qstr) + forced_ordering = connection.ops.force_no_ordering() + if forced_ordering: + # If the backend needs to force an ordering we make sure it's + # the only "ORDER BY" clause present in the query. + self.assertEqual( + re.findall(r'order by (\w+)', qstr), + [', '.join(f[1][0] for f in forced_ordering).lower()] + ) + else: + self.assertNotIn('order by', qstr) + self.assertEqual(qstr.count(' join '), 0) + + def test_decimal_max_digits_has_no_effect(self): + Book.objects.all().delete() + a1 = Author.objects.first() + p1 = Publisher.objects.first() + thedate = timezone.now() + for i in range(10): + Book.objects.create( + isbn="abcde{}".format(i), name="none", pages=10, rating=4.0, + price=9999.98, contact=a1, publisher=p1, pubdate=thedate) + + book = Book.objects.aggregate(price_sum=Sum('price')) + self.assertEqual(book['price_sum'], Decimal("99999.80")) + + def test_nonaggregate_aggregation_throws(self): + with self.assertRaisesMessage(TypeError, 'fail is not an aggregate expression'): + Book.objects.aggregate(fail=F('price')) + + def test_nonfield_annotation(self): + book = Book.objects.annotate(val=Max(Value(2, output_field=IntegerField()))).first() + self.assertEqual(book.val, 2) + book = Book.objects.annotate(val=Max(Value(2), output_field=IntegerField())).first() + self.assertEqual(book.val, 2) + book = Book.objects.annotate(val=Max(2, output_field=IntegerField())).first() + self.assertEqual(book.val, 2) + + def test_missing_output_field_raises_error(self): + with self.assertRaisesMessage(FieldError, 'Cannot resolve expression type, unknown output_field'): + Book.objects.annotate(val=Max(2)).first() + + def 
test_annotation_expressions(self): + authors = Author.objects.annotate(combined_ages=Sum(F('age') + F('friends__age'))).order_by('name') + authors2 = Author.objects.annotate(combined_ages=Sum('age') + Sum('friends__age')).order_by('name') + for qs in (authors, authors2): + self.assertQuerysetEqual( + qs, [ + ('Adrian Holovaty', 132), + ('Brad Dayley', None), + ('Jacob Kaplan-Moss', 129), + ('James Bennett', 63), + ('Jeffrey Forcier', 128), + ('Paul Bissex', 120), + ('Peter Norvig', 103), + ('Stuart Russell', 103), + ('Wesley J. Chun', 176) + ], + lambda a: (a.name, a.combined_ages) + ) + + def test_aggregation_expressions(self): + a1 = Author.objects.aggregate(av_age=Sum('age') / Count('*')) + a2 = Author.objects.aggregate(av_age=Sum('age') / Count('age')) + a3 = Author.objects.aggregate(av_age=Avg('age')) + self.assertEqual(a1, {'av_age': 37}) + self.assertEqual(a2, {'av_age': 37}) + self.assertEqual(a3, {'av_age': Approximate(37.4, places=1)}) + + def test_avg_decimal_field(self): + v = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price')))['avg_price'] + self.assertIsInstance(v, float) + self.assertEqual(v, Approximate(47.39, places=2)) + + def test_order_of_precedence(self): + p1 = Book.objects.filter(rating=4).aggregate(avg_price=(Avg('price') + 2) * 3) + self.assertEqual(p1, {'avg_price': Approximate(148.18, places=2)}) + + p2 = Book.objects.filter(rating=4).aggregate(avg_price=Avg('price') + 2 * 3) + self.assertEqual(p2, {'avg_price': Approximate(53.39, places=2)}) + + def test_combine_different_types(self): + msg = 'Expression contains mixed types. You must set output_field.' 
+ qs = Book.objects.annotate(sums=Sum('rating') + Sum('pages') + Sum('price')) + with self.assertRaisesMessage(FieldError, msg): + qs.first() + with self.assertRaisesMessage(FieldError, msg): + qs.first() + + b1 = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'), + output_field=IntegerField())).get(pk=self.b4.pk) + self.assertEqual(b1.sums, 383) + + b2 = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'), + output_field=FloatField())).get(pk=self.b4.pk) + self.assertEqual(b2.sums, 383.69) + + b3 = Book.objects.annotate(sums=Sum(F('rating') + F('pages') + F('price'), + output_field=DecimalField())).get(pk=self.b4.pk) + self.assertEqual(b3.sums, Approximate(Decimal("383.69"), places=2)) + + def test_complex_aggregations_require_kwarg(self): + with self.assertRaisesMessage(TypeError, 'Complex annotations require an alias'): + Author.objects.annotate(Sum(F('age') + F('friends__age'))) + with self.assertRaisesMessage(TypeError, 'Complex aggregates require an alias'): + Author.objects.aggregate(Sum('age') / Count('age')) + with self.assertRaisesMessage(TypeError, 'Complex aggregates require an alias'): + Author.objects.aggregate(Sum(1)) + + def test_aggregate_over_complex_annotation(self): + qs = Author.objects.annotate( + combined_ages=Sum(F('age') + F('friends__age'))) + + age = qs.aggregate(max_combined_age=Max('combined_ages')) + self.assertEqual(age['max_combined_age'], 176) + + age = qs.aggregate(max_combined_age_doubled=Max('combined_ages') * 2) + self.assertEqual(age['max_combined_age_doubled'], 176 * 2) + + age = qs.aggregate( + max_combined_age_doubled=Max('combined_ages') + Max('combined_ages')) + self.assertEqual(age['max_combined_age_doubled'], 176 * 2) + + age = qs.aggregate( + max_combined_age_doubled=Max('combined_ages') + Max('combined_ages'), + sum_combined_age=Sum('combined_ages')) + self.assertEqual(age['max_combined_age_doubled'], 176 * 2) + self.assertEqual(age['sum_combined_age'], 954) + + age = qs.aggregate( + 
max_combined_age_doubled=Max('combined_ages') + Max('combined_ages'), + sum_combined_age_doubled=Sum('combined_ages') + Sum('combined_ages')) + self.assertEqual(age['max_combined_age_doubled'], 176 * 2) + self.assertEqual(age['sum_combined_age_doubled'], 954 * 2) + + def test_values_annotation_with_expression(self): + # ensure the F() is promoted to the group by clause + qs = Author.objects.values('name').annotate(another_age=Sum('age') + F('age')) + a = qs.get(name="Adrian Holovaty") + self.assertEqual(a['another_age'], 68) + + qs = qs.annotate(friend_count=Count('friends')) + a = qs.get(name="Adrian Holovaty") + self.assertEqual(a['friend_count'], 2) + + qs = qs.annotate(combined_age=Sum('age') + F('friends__age')).filter( + name="Adrian Holovaty").order_by('-combined_age') + self.assertEqual( + list(qs), [ + { + "name": 'Adrian Holovaty', + "another_age": 68, + "friend_count": 1, + "combined_age": 69 + }, + { + "name": 'Adrian Holovaty', + "another_age": 68, + "friend_count": 1, + "combined_age": 63 + } + ] + ) + + vals = qs.values('name', 'combined_age') + self.assertEqual( + list(vals), [ + {'name': 'Adrian Holovaty', 'combined_age': 69}, + {'name': 'Adrian Holovaty', 'combined_age': 63}, + ] + ) + + def test_annotate_values_aggregate(self): + alias_age = Author.objects.annotate( + age_alias=F('age') + ).values( + 'age_alias', + ).aggregate(sum_age=Sum('age_alias')) + + age = Author.objects.values('age').aggregate(sum_age=Sum('age')) + + self.assertEqual(alias_age['sum_age'], age['sum_age']) + + def test_annotate_over_annotate(self): + author = Author.objects.annotate( + age_alias=F('age') + ).annotate( + sum_age=Sum('age_alias') + ).get(name="Adrian Holovaty") + + other_author = Author.objects.annotate( + sum_age=Sum('age') + ).get(name="Adrian Holovaty") + + self.assertEqual(author.sum_age, other_author.sum_age) + + def test_annotated_aggregate_over_annotated_aggregate(self): + with self.assertRaisesMessage(FieldError, "Cannot compute Sum('id__max'): 
'id__max' is an aggregate"): + Book.objects.annotate(Max('id')).annotate(Sum('id__max')) + + class MyMax(Max): + def as_sql(self, compiler, connection): + self.set_source_expressions(self.get_source_expressions()[0:1]) + return super().as_sql(compiler, connection) + + with self.assertRaisesMessage(FieldError, "Cannot compute Max('id__max'): 'id__max' is an aggregate"): + Book.objects.annotate(Max('id')).annotate(my_max=MyMax('id__max', 'price')) + + def test_multi_arg_aggregate(self): + class MyMax(Max): + output_field = DecimalField() + + def as_sql(self, compiler, connection): + copy = self.copy() + copy.set_source_expressions(copy.get_source_expressions()[0:1]) + return super(MyMax, copy).as_sql(compiler, connection) + + with self.assertRaisesMessage(TypeError, 'Complex aggregates require an alias'): + Book.objects.aggregate(MyMax('pages', 'price')) + + with self.assertRaisesMessage(TypeError, 'Complex annotations require an alias'): + Book.objects.annotate(MyMax('pages', 'price')) + + Book.objects.aggregate(max_field=MyMax('pages', 'price')) + + def test_add_implementation(self): + class MySum(Sum): + pass + + # test completely changing how the output is rendered + def lower_case_function_override(self, compiler, connection): + sql, params = compiler.compile(self.source_expressions[0]) + substitutions = {'function': self.function.lower(), 'expressions': sql} + substitutions.update(self.extra) + return self.template % substitutions, params + setattr(MySum, 'as_' + connection.vendor, lower_case_function_override) + + qs = Book.objects.annotate( + sums=MySum(F('rating') + F('pages') + F('price'), output_field=IntegerField()) + ) + self.assertEqual(str(qs.query).count('sum('), 1) + b1 = qs.get(pk=self.b4.pk) + self.assertEqual(b1.sums, 383) + + # test changing the dict and delegating + def lower_case_function_super(self, compiler, connection): + self.extra['function'] = self.function.lower() + return super(MySum, self).as_sql(compiler, connection) + setattr(MySum, 
'as_' + connection.vendor, lower_case_function_super) + + qs = Book.objects.annotate( + sums=MySum(F('rating') + F('pages') + F('price'), output_field=IntegerField()) + ) + self.assertEqual(str(qs.query).count('sum('), 1) + b1 = qs.get(pk=self.b4.pk) + self.assertEqual(b1.sums, 383) + + # test overriding all parts of the template + def be_evil(self, compiler, connection): + substitutions = {'function': 'MAX', 'expressions': '2'} + substitutions.update(self.extra) + return self.template % substitutions, () + setattr(MySum, 'as_' + connection.vendor, be_evil) + + qs = Book.objects.annotate( + sums=MySum(F('rating') + F('pages') + F('price'), output_field=IntegerField()) + ) + self.assertEqual(str(qs.query).count('MAX('), 1) + b1 = qs.get(pk=self.b4.pk) + self.assertEqual(b1.sums, 2) + + def test_complex_values_aggregation(self): + max_rating = Book.objects.values('rating').aggregate( + double_max_rating=Max('rating') + Max('rating')) + self.assertEqual(max_rating['double_max_rating'], 5 * 2) + + max_books_per_rating = Book.objects.values('rating').annotate( + books_per_rating=Count('id') + 5 + ).aggregate(Max('books_per_rating')) + self.assertEqual( + max_books_per_rating, + {'books_per_rating__max': 3 + 5}) + +# def test_expression_on_aggregation(self): +# +# # Create a plain expression +# class Greatest(Func): +# function = 'GREATEST' +# +# def as_sqlite(self, compiler, connection): +# return super().as_sql(compiler, connection, function='MAX') +# +# qs = Publisher.objects.annotate( +# price_or_median=Greatest(Avg('book__rating'), Avg('book__price')) +# ).filter(price_or_median__gte=F('num_awards')).order_by('num_awards') +# self.assertQuerysetEqual( +# qs, [1, 3, 7, 9], lambda v: v.num_awards) +# +# qs2 = Publisher.objects.annotate( +# rating_or_num_awards=Greatest(Avg('book__rating'), F('num_awards'), +# output_field=FloatField()) +# ).filter(rating_or_num_awards__gt=F('num_awards')).order_by('num_awards') +# self.assertQuerysetEqual( +# qs2, [1, 3], lambda v: 
v.num_awards) + + def test_arguments_must_be_expressions(self): + msg = 'QuerySet.aggregate() received non-expression(s): %s.' + with self.assertRaisesMessage(TypeError, msg % FloatField()): + Book.objects.aggregate(FloatField()) + with self.assertRaisesMessage(TypeError, msg % True): + Book.objects.aggregate(is_book=True) + with self.assertRaisesMessage(TypeError, msg % ', '.join([str(FloatField()), 'True'])): + Book.objects.aggregate(FloatField(), Avg('price'), is_book=True) diff --git a/tests/aggregation_regress/__init__.py b/tests/aggregation_regress/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/aggregation_regress/models.py b/tests/aggregation_regress/models.py new file mode 100644 index 00000000..3498cbf1 --- /dev/null +++ b/tests/aggregation_regress/models.py @@ -0,0 +1,104 @@ +from django.contrib.contenttypes.fields import ( + GenericForeignKey, GenericRelation, +) +from django.contrib.contenttypes.models import ContentType +from django.db import models + + +class Author(models.Model): + name = models.CharField(max_length=100) + age = models.IntegerField() + friends = models.ManyToManyField('self', blank=True) + + def __str__(self): + return self.name + + +class Publisher(models.Model): + name = models.CharField(max_length=255) + num_awards = models.IntegerField() + + def __str__(self): + return self.name + + +class ItemTag(models.Model): + tag = models.CharField(max_length=100) + content_type = models.ForeignKey(ContentType, models.CASCADE) + object_id = models.PositiveIntegerField() + content_object = GenericForeignKey('content_type', 'object_id') + + +class Book(models.Model): + isbn = models.CharField(max_length=9) + name = models.CharField(max_length=255) + pages = models.IntegerField() + rating = models.FloatField() + price = models.DecimalField(decimal_places=2, max_digits=6) + authors = models.ManyToManyField(Author) + contact = models.ForeignKey(Author, models.CASCADE, related_name='book_contact_set') + publisher = 
models.ForeignKey(Publisher, models.CASCADE) + pubdate = models.DateField() + tags = GenericRelation(ItemTag) + + class Meta: + ordering = ('name',) + + def __str__(self): + return self.name + + +class Store(models.Model): + name = models.CharField(max_length=255) + books = models.ManyToManyField(Book) + original_opening = models.DateTimeField() + friday_night_closing = models.TimeField() + + def __str__(self): + return self.name + + +class Entries(models.Model): + EntryID = models.AutoField(primary_key=True, db_column='Entry ID') + Entry = models.CharField(unique=True, max_length=50) + Exclude = models.BooleanField(default=False) + + +class Clues(models.Model): + ID = models.AutoField(primary_key=True) + EntryID = models.ForeignKey(Entries, models.CASCADE, verbose_name='Entry', db_column='Entry ID') + Clue = models.CharField(max_length=150) + + +class WithManualPK(models.Model): + # The generic relations regression test needs two different model + # classes with the same PK value, and there are some (external) + # DB backends that don't work nicely when assigning integer to AutoField + # column (MSSQL at least). 
+ id = models.IntegerField(primary_key=True) + + +class HardbackBook(Book): + weight = models.FloatField() + + def __str__(self): + return "%s (hardback): %s" % (self.name, self.weight) + + +# Models for ticket #21150 +class Alfa(models.Model): + name = models.CharField(max_length=10, null=True) + + +class Bravo(models.Model): + pass + + +class Charlie(models.Model): + alfa = models.ForeignKey(Alfa, models.SET_NULL, null=True) + bravo = models.ForeignKey(Bravo, models.SET_NULL, null=True) + + +class SelfRefFK(models.Model): + name = models.CharField(max_length=50) + parent = models.ForeignKey('self', models.SET_NULL, null=True, blank=True, related_name='children') diff --git a/tests/aggregation_regress/tests.py b/tests/aggregation_regress/tests.py new file mode 100644 index 00000000..becac848 --- /dev/null +++ b/tests/aggregation_regress/tests.py @@ -0,0 +1,1530 @@ +import datetime +import pickle +from decimal import Decimal +from operator import attrgetter +from unittest import mock + +from django.contrib.contenttypes.models import ContentType +from django.core.exceptions import FieldError +from django.db import connection +from django.db.models import ( + Avg, Case, Count, DecimalField, F, IntegerField, Max, Q, StdDev, Sum, + Value, Variance, When, +) +from django.test import TestCase, skipUnlessAnyDBFeature, skipUnlessDBFeature +from django.test.utils import Approximate + +from .models import ( + Alfa, Author, Book, Bravo, Charlie, Clues, Entries, HardbackBook, ItemTag, + Publisher, SelfRefFK, Store, WithManualPK, +) + + +class AggregationTests(TestCase): + + @classmethod + def setUpTestData(cls): + cls.a1 = Author.objects.create(name='Adrian Holovaty', age=34) + cls.a2 = Author.objects.create(name='Jacob Kaplan-Moss', age=35) + cls.a3 = Author.objects.create(name='Brad Dayley', age=45) + cls.a4 = Author.objects.create(name='James Bennett', age=29) + cls.a5 = Author.objects.create(name='Jeffrey Forcier', age=37) + cls.a6 = Author.objects.create(name='Paul 
Bissex', age=29) + cls.a7 = Author.objects.create(name='Wesley J. Chun', age=25) + cls.a8 = Author.objects.create(name='Peter Norvig', age=57) + cls.a9 = Author.objects.create(name='Stuart Russell', age=46) + cls.a1.friends.add(cls.a2, cls.a4) + cls.a2.friends.add(cls.a1, cls.a7) + cls.a4.friends.add(cls.a1) + cls.a5.friends.add(cls.a6, cls.a7) + cls.a6.friends.add(cls.a5, cls.a7) + cls.a7.friends.add(cls.a2, cls.a5, cls.a6) + cls.a8.friends.add(cls.a9) + cls.a9.friends.add(cls.a8) + + cls.p1 = Publisher.objects.create(name='Apress', num_awards=3) + cls.p2 = Publisher.objects.create(name='Sams', num_awards=1) + cls.p3 = Publisher.objects.create(name='Prentice Hall', num_awards=7) + cls.p4 = Publisher.objects.create(name='Morgan Kaufmann', num_awards=9) + cls.p5 = Publisher.objects.create(name="Jonno's House of Books", num_awards=0) + + cls.b1 = Book.objects.create( + isbn='159059725', name='The Definitive Guide to Django: Web Development Done Right', + pages=447, rating=4.5, price=Decimal('30.00'), contact=cls.a1, publisher=cls.p1, + pubdate=datetime.date(2007, 12, 6) + ) + cls.b2 = Book.objects.create( + isbn='067232959', name='Sams Teach Yourself Django in 24 Hours', + pages=528, rating=3.0, price=Decimal('23.09'), contact=cls.a3, publisher=cls.p2, + pubdate=datetime.date(2008, 3, 3) + ) + cls.b3 = Book.objects.create( + isbn='159059996', name='Practical Django Projects', + pages=300, rating=4.0, price=Decimal('29.69'), contact=cls.a4, publisher=cls.p1, + pubdate=datetime.date(2008, 6, 23) + ) + cls.b4 = Book.objects.create( + isbn='013235613', name='Python Web Development with Django', + pages=350, rating=4.0, price=Decimal('29.69'), contact=cls.a5, publisher=cls.p3, + pubdate=datetime.date(2008, 11, 3) + ) + cls.b5 = HardbackBook.objects.create( + isbn='013790395', name='Artificial Intelligence: A Modern Approach', + pages=1132, rating=4.0, price=Decimal('82.80'), contact=cls.a8, publisher=cls.p3, + pubdate=datetime.date(1995, 1, 15), weight=4.5) + cls.b6 = 
HardbackBook.objects.create( + isbn='155860191', name='Paradigms of Artificial Intelligence Programming: Case Studies in Common Lisp', + pages=946, rating=5.0, price=Decimal('75.00'), contact=cls.a8, publisher=cls.p4, + pubdate=datetime.date(1991, 10, 15), weight=3.7) + cls.b1.authors.add(cls.a1, cls.a2) + cls.b2.authors.add(cls.a3) + cls.b3.authors.add(cls.a4) + cls.b4.authors.add(cls.a5, cls.a6, cls.a7) + cls.b5.authors.add(cls.a8, cls.a9) + cls.b6.authors.add(cls.a8) + + s1 = Store.objects.create( + name='Amazon.com', + original_opening=datetime.datetime(1994, 4, 23, 9, 17, 42), + friday_night_closing=datetime.time(23, 59, 59) + ) + s2 = Store.objects.create( + name='Books.com', + original_opening=datetime.datetime(2001, 3, 15, 11, 23, 37), + friday_night_closing=datetime.time(23, 59, 59) + ) + s3 = Store.objects.create( + name="Mamma and Pappa's Books", + original_opening=datetime.datetime(1945, 4, 25, 16, 24, 14), + friday_night_closing=datetime.time(21, 30) + ) + s1.books.add(cls.b1, cls.b2, cls.b3, cls.b4, cls.b5, cls.b6) + s2.books.add(cls.b1, cls.b3, cls.b5, cls.b6) + s3.books.add(cls.b3, cls.b4, cls.b6) + + def assertObjectAttrs(self, obj, **kwargs): + for attr, value in kwargs.items(): + self.assertEqual(getattr(obj, attr), value) + + #def test_annotation_with_value(self): + # values = Book.objects.filter( + # name='Practical Django Projects', + # ).annotate( + # discount_price=F('price') * 2, + # ).values( + # 'discount_price', + # ).annotate(sum_discount=Sum('discount_price')) + # self.assertSequenceEqual( + # values, + # [{'discount_price': Decimal('59.38'), 'sum_discount': Decimal('59.38')}] + # ) + + def test_aggregates_in_where_clause(self): + """ + Regression test for #12822: DatabaseError: aggregates not allowed in + WHERE clause + + The subselect works and returns results equivalent to a + query with the IDs listed. + + Before the corresponding fix for this bug, this test passed in 1.1 and + failed in 1.2-beta (trunk). 
+ """ + qs = Book.objects.values('contact').annotate(Max('id')) + qs = qs.order_by('contact').values_list('id__max', flat=True) + # don't do anything with the queryset (qs) before including it as a + # subquery + books = Book.objects.order_by('id') + qs1 = books.filter(id__in=qs) + qs2 = books.filter(id__in=list(qs)) + self.assertEqual(list(qs1), list(qs2)) + + def test_aggregates_in_where_clause_pre_eval(self): + """ + Regression test for #12822: DatabaseError: aggregates not allowed in + WHERE clause + + Same as the above test, but evaluates the queryset for the subquery + before it's used as a subquery. + + Before the corresponding fix for this bug, this test failed in both + 1.1 and 1.2-beta (trunk). + """ + qs = Book.objects.values('contact').annotate(Max('id')) + qs = qs.order_by('contact').values_list('id__max', flat=True) + # force the queryset (qs) for the subquery to be evaluated in its + # current state + list(qs) + books = Book.objects.order_by('id') + qs1 = books.filter(id__in=qs) + qs2 = books.filter(id__in=list(qs)) + self.assertEqual(list(qs1), list(qs2)) + + @skipUnlessDBFeature('supports_subqueries_in_group_by') + def test_annotate_with_extra(self): + """ + Regression test for #11916: Extra params + aggregation creates + incorrect SQL. 
+ """ + # Oracle doesn't support subqueries in group by clause + shortest_book_sql = """ + SELECT name + FROM aggregation_regress_book b + WHERE b.publisher_id = aggregation_regress_publisher.id + ORDER BY b.pages + LIMIT 1 + """ + # tests that this query does not raise a DatabaseError due to the full + # subselect being (erroneously) added to the GROUP BY parameters + qs = Publisher.objects.extra(select={ + 'name_of_shortest_book': shortest_book_sql, + }).annotate(total_books=Count('book')) + # force execution of the query + list(qs) + + def test_aggregate(self): + # Ordering requests are ignored + self.assertEqual( + Author.objects.order_by("name").aggregate(Avg("age")), + {"age__avg": Approximate(37.444, places=1)} + ) + + # Implicit ordering is also ignored + self.assertEqual( + Book.objects.aggregate(Sum("pages")), + {"pages__sum": 3703}, + ) + + # Baseline results + self.assertEqual( + Book.objects.aggregate(Sum('pages'), Avg('pages')), + {'pages__sum': 3703, 'pages__avg': Approximate(617.166, places=2)} + ) + + # Empty values query doesn't affect grouping or results + self.assertEqual( + Book.objects.values().aggregate(Sum('pages'), Avg('pages')), + {'pages__sum': 3703, 'pages__avg': Approximate(617.166, places=2)} + ) + + # Aggregate overrides extra selected column + self.assertEqual( + Book.objects.extra(select={'price_per_page': 'price / pages'}).aggregate(Sum('pages')), + {'pages__sum': 3703} + ) + + def test_annotation(self): + # Annotations get combined with extra select clauses + obj = Book.objects.annotate(mean_auth_age=Avg("authors__age")).extra( + select={"manufacture_cost": "price * .5"}).get(pk=self.b2.pk) + self.assertObjectAttrs( + obj, + contact_id=self.a3.id, + isbn='067232959', + mean_auth_age=45.0, + name='Sams Teach Yourself Django in 24 Hours', + pages=528, + price=Decimal("23.09"), + pubdate=datetime.date(2008, 3, 3), + publisher_id=self.p2.id, + rating=3.0 + ) + # Different DB backends return different types for the extra select 
computation + self.assertIn(obj.manufacture_cost, (11.545, Decimal('11.545'))) + + # Order of the annotate/extra in the query doesn't matter + obj = Book.objects.extra(select={'manufacture_cost': 'price * .5'}).annotate( + mean_auth_age=Avg('authors__age')).get(pk=self.b2.pk) + self.assertObjectAttrs( + obj, + contact_id=self.a3.id, + isbn='067232959', + mean_auth_age=45.0, + name='Sams Teach Yourself Django in 24 Hours', + pages=528, + price=Decimal("23.09"), + pubdate=datetime.date(2008, 3, 3), + publisher_id=self.p2.id, + rating=3.0 + ) + # Different DB backends return different types for the extra select computation + self.assertIn(obj.manufacture_cost, (11.545, Decimal('11.545'))) + + # Values queries can be combined with annotate and extra + obj = Book.objects.annotate(mean_auth_age=Avg('authors__age')).extra( + select={'manufacture_cost': 'price * .5'}).values().get(pk=self.b2.pk) + manufacture_cost = obj['manufacture_cost'] + self.assertIn(manufacture_cost, (11.545, Decimal('11.545'))) + del obj['manufacture_cost'] + self.assertEqual(obj, { + 'id': self.b2.id, + 'contact_id': self.a3.id, + 'isbn': '067232959', + 'mean_auth_age': 45.0, + 'name': 'Sams Teach Yourself Django in 24 Hours', + 'pages': 528, + 'price': Decimal('23.09'), + 'pubdate': datetime.date(2008, 3, 3), + 'publisher_id': self.p2.id, + 'rating': 3.0, + }) + + # The order of the (empty) values, annotate and extra clauses doesn't + # matter + obj = Book.objects.values().annotate(mean_auth_age=Avg('authors__age')).extra( + select={'manufacture_cost': 'price * .5'}).get(pk=self.b2.pk) + manufacture_cost = obj['manufacture_cost'] + self.assertIn(manufacture_cost, (11.545, Decimal('11.545'))) + del obj['manufacture_cost'] + self.assertEqual(obj, { + 'id': self.b2.id, + 'contact_id': self.a3.id, + 'isbn': '067232959', + 'mean_auth_age': 45.0, + 'name': 'Sams Teach Yourself Django in 24 Hours', + 'pages': 528, + 'price': Decimal('23.09'), + 'pubdate': datetime.date(2008, 3, 3), + 'publisher_id': 
self.p2.id, + 'rating': 3.0 + }) + + # If the annotation precedes the values clause, it won't be included + # unless it is explicitly named + obj = Book.objects.annotate(mean_auth_age=Avg('authors__age')).extra( + select={'price_per_page': 'price / pages'}).values('name').get(pk=self.b1.pk) + self.assertEqual(obj, { + "name": 'The Definitive Guide to Django: Web Development Done Right', + }) + + obj = Book.objects.annotate(mean_auth_age=Avg('authors__age')).extra( + select={'price_per_page': 'price / pages'}).values('name', 'mean_auth_age').get(pk=self.b1.pk) + self.assertEqual(obj, { + 'mean_auth_age': 34.5, + 'name': 'The Definitive Guide to Django: Web Development Done Right', + }) + + # If an annotation isn't included in the values, it can still be used + # in a filter + qs = Book.objects.annotate(n_authors=Count('authors')).values('name').filter(n_authors__gt=2) + self.assertSequenceEqual( + qs, [ + {"name": 'Python Web Development with Django'} + ], + ) + + # The annotations are added to values output if values() precedes + # annotate() + obj = Book.objects.values('name').annotate(mean_auth_age=Avg('authors__age')).extra( + select={'price_per_page': 'price / pages'}).get(pk=self.b1.pk) + self.assertEqual(obj, { + 'mean_auth_age': 34.5, + 'name': 'The Definitive Guide to Django: Web Development Done Right', + }) + + # All of the objects are getting counted (allow_nulls) and that values + # respects the amount of objects + self.assertEqual( + len(Author.objects.annotate(Avg('friends__age')).values()), + 9 + ) + + # Consecutive calls to annotate accumulate in the query + qs = ( + Book.objects + .values('price') + .annotate(oldest=Max('authors__age')) + .order_by('oldest', 'price') + .annotate(Max('publisher__num_awards')) + ) + self.assertSequenceEqual( + qs, [ + {'price': Decimal("30"), 'oldest': 35, 'publisher__num_awards__max': 3}, + {'price': Decimal("29.69"), 'oldest': 37, 'publisher__num_awards__max': 7}, + {'price': Decimal("23.09"), 'oldest': 45, 
'publisher__num_awards__max': 1}, + {'price': Decimal("75"), 'oldest': 57, 'publisher__num_awards__max': 9}, + {'price': Decimal("82.8"), 'oldest': 57, 'publisher__num_awards__max': 7} + ], + ) + + def test_aggregate_annotation(self): + # Aggregates can be composed over annotations. + # The return type is derived from the composed aggregate + vals = ( + Book.objects + .all() + .annotate(num_authors=Count('authors__id')) + .aggregate(Max('pages'), Max('price'), Sum('num_authors'), Avg('num_authors')) + ) + self.assertEqual(vals, { + 'num_authors__sum': 10, + 'num_authors__avg': Approximate(1.666, places=2), + 'pages__max': 1132, + 'price__max': Decimal("82.80") + }) + + # Regression for #15624 - Missing SELECT columns when using values, annotate + # and aggregate in a single query + self.assertEqual( + Book.objects.annotate(c=Count('authors')).values('c').aggregate(Max('c')), + {'c__max': 3} + ) + + def test_conditional_aggreate(self): + # Conditional aggregation of a grouped queryset. + self.assertEqual( + Book.objects.annotate(c=Count('authors')).values('pk').aggregate(test=Sum( + Case(When(c__gt=1, then=1), output_field=IntegerField()) + ))['test'], + 3 + ) + + def test_sliced_conditional_aggregate(self): + self.assertEqual( + Author.objects.all()[:5].aggregate(test=Sum(Case( + When(age__lte=35, then=1), output_field=IntegerField() + )))['test'], + 3 + ) + + #def test_annotated_conditional_aggregate(self): + # annotated_qs = Book.objects.annotate(discount_price=F('price') * 0.75) + # self.assertAlmostEqual( + # annotated_qs.aggregate(test=Avg(Case( + # When(pages__lt=400, then='discount_price'), + # output_field=DecimalField() + # )))['test'], + # 22.27, places=2 + # ) + + def test_distinct_conditional_aggregate(self): + self.assertEqual( + Book.objects.distinct().aggregate(test=Avg(Case( + When(price=Decimal('29.69'), then='pages'), + output_field=IntegerField() + )))['test'], + 325 + ) + + def test_conditional_aggregate_on_complex_condition(self): + 
self.assertEqual( + Book.objects.distinct().aggregate(test=Avg(Case( + When(Q(price__gte=Decimal('29')) & Q(price__lt=Decimal('30')), then='pages'), + output_field=IntegerField() + )))['test'], + 325 + ) + + def test_decimal_aggregate_annotation_filter(self): + """ + Filtering on an aggregate annotation with Decimal values should work. + Requires special handling on SQLite (#18247). + """ + self.assertEqual( + len(Author.objects.annotate(sum=Sum('book_contact_set__price')).filter(sum__gt=Decimal(40))), + 1 + ) + self.assertEqual( + len(Author.objects.annotate(sum=Sum('book_contact_set__price')).filter(sum__lte=Decimal(40))), + 4 + ) + + def test_field_error(self): + # Bad field requests in aggregates are caught and reported + msg = ( + "Cannot resolve keyword 'foo' into field. Choices are: authors, " + "contact, contact_id, hardbackbook, id, isbn, name, pages, price, " + "pubdate, publisher, publisher_id, rating, store, tags" + ) + with self.assertRaisesMessage(FieldError, msg): + Book.objects.all().aggregate(num_authors=Count('foo')) + + with self.assertRaisesMessage(FieldError, msg): + Book.objects.all().annotate(num_authors=Count('foo')) + + msg = ( + "Cannot resolve keyword 'foo' into field. 
Choices are: authors, " + "contact, contact_id, hardbackbook, id, isbn, name, num_authors, " + "pages, price, pubdate, publisher, publisher_id, rating, store, tags" + ) + with self.assertRaisesMessage(FieldError, msg): + Book.objects.all().annotate(num_authors=Count('authors__id')).aggregate(Max('foo')) + + def test_more(self): + # Old-style count aggregations can be mixed with new-style + self.assertEqual( + Book.objects.annotate(num_authors=Count('authors')).count(), + 6 + ) + + # Non-ordinal, non-computed Aggregates over annotations correctly + # inherit the annotation's internal type if the annotation is ordinal + # or computed + vals = Book.objects.annotate(num_authors=Count('authors')).aggregate(Max('num_authors')) + self.assertEqual( + vals, + {'num_authors__max': 3} + ) + + vals = Publisher.objects.annotate(avg_price=Avg('book__price')).aggregate(Max('avg_price')) + self.assertEqual( + vals, + {'avg_price__max': 75.0} + ) + + # Aliases are quoted to protected aliases that might be reserved names + vals = Book.objects.aggregate(number=Max('pages'), select=Max('pages')) + self.assertEqual( + vals, + {'number': 1132, 'select': 1132} + ) + + # Regression for #10064: select_related() plays nice with aggregates + obj = Book.objects.select_related('publisher').annotate( + num_authors=Count('authors')).values().get(isbn='013790395') + self.assertEqual(obj, { + 'contact_id': self.a8.id, + 'id': self.b5.id, + 'isbn': '013790395', + 'name': 'Artificial Intelligence: A Modern Approach', + 'num_authors': 2, + 'pages': 1132, + 'price': Decimal("82.8"), + 'pubdate': datetime.date(1995, 1, 15), + 'publisher_id': self.p3.id, + 'rating': 4.0, + }) + + # Regression for #10010: exclude on an aggregate field is correctly + # negated + self.assertEqual( + len(Book.objects.annotate(num_authors=Count('authors'))), + 6 + ) + self.assertEqual( + len(Book.objects.annotate(num_authors=Count('authors')).filter(num_authors__gt=2)), + 1 + ) + self.assertEqual( + 
len(Book.objects.annotate(num_authors=Count('authors')).exclude(num_authors__gt=2)), + 5 + ) + + self.assertEqual( + len( + Book.objects + .annotate(num_authors=Count('authors')) + .filter(num_authors__lt=3) + .exclude(num_authors__lt=2) + ), + 2 + ) + self.assertEqual( + len( + Book.objects + .annotate(num_authors=Count('authors')) + .exclude(num_authors__lt=2) + .filter(num_authors__lt=3) + ), + 2 + ) + + def test_aggregate_fexpr(self): + # Aggregates can be used with F() expressions + # ... where the F() is pushed into the HAVING clause + qs = ( + Publisher.objects + .annotate(num_books=Count('book')) + .filter(num_books__lt=F('num_awards') / 2) + .order_by('name') + .values('name', 'num_books', 'num_awards') + ) + self.assertSequenceEqual( + qs, [ + {'num_books': 1, 'name': 'Morgan Kaufmann', 'num_awards': 9}, + {'num_books': 2, 'name': 'Prentice Hall', 'num_awards': 7} + ], + ) + + qs = ( + Publisher.objects + .annotate(num_books=Count('book')) + .exclude(num_books__lt=F('num_awards') / 2) + .order_by('name') + .values('name', 'num_books', 'num_awards') + ) + self.assertSequenceEqual( + qs, [ + {'num_books': 2, 'name': 'Apress', 'num_awards': 3}, + {'num_books': 0, 'name': "Jonno's House of Books", 'num_awards': 0}, + {'num_books': 1, 'name': 'Sams', 'num_awards': 1} + ], + ) + + # ... 
and where the F() references an aggregate + qs = ( + Publisher.objects + .annotate(num_books=Count('book')) + .filter(num_awards__gt=2 * F('num_books')) + .order_by('name') + .values('name', 'num_books', 'num_awards') + ) + self.assertSequenceEqual( + qs, [ + {'num_books': 1, 'name': 'Morgan Kaufmann', 'num_awards': 9}, + {'num_books': 2, 'name': 'Prentice Hall', 'num_awards': 7} + ], + ) + + qs = ( + Publisher.objects + .annotate(num_books=Count('book')) + .exclude(num_books__lt=F('num_awards') / 2) + .order_by('name') + .values('name', 'num_books', 'num_awards') + ) + self.assertSequenceEqual( + qs, [ + {'num_books': 2, 'name': 'Apress', 'num_awards': 3}, + {'num_books': 0, 'name': "Jonno's House of Books", 'num_awards': 0}, + {'num_books': 1, 'name': 'Sams', 'num_awards': 1} + ], + ) + + def test_db_col_table(self): + # Tests on fields with non-default table and column names. + qs = ( + Clues.objects + .values('EntryID__Entry') + .annotate(Appearances=Count('EntryID'), Distinct_Clues=Count('Clue', distinct=True)) + ) + self.assertQuerysetEqual(qs, []) + + qs = Entries.objects.annotate(clue_count=Count('clues__ID')) + self.assertQuerysetEqual(qs, []) + + def test_boolean_conversion(self): + # Aggregates mixed up ordering of columns for backend's convert_values + # method. Refs #21126. 
+ e = Entries.objects.create(Entry='foo') + c = Clues.objects.create(EntryID=e, Clue='bar') + qs = Clues.objects.select_related('EntryID').annotate(Count('ID')) + self.assertSequenceEqual(qs, [c]) + self.assertEqual(qs[0].EntryID, e) + self.assertIs(qs[0].EntryID.Exclude, False) + + def test_empty(self): + # Regression for #10089: Check handling of empty result sets with + # aggregates + self.assertEqual( + Book.objects.filter(id__in=[]).count(), + 0 + ) + + vals = ( + Book.objects + .filter(id__in=[]) + .aggregate( + num_authors=Count('authors'), + avg_authors=Avg('authors'), + max_authors=Max('authors'), + max_price=Max('price'), + max_rating=Max('rating'), + ) + ) + self.assertEqual( + vals, + {'max_authors': None, 'max_rating': None, 'num_authors': 0, 'avg_authors': None, 'max_price': None} + ) + + qs = ( + Publisher.objects + .filter(name="Jonno's House of Books") + .annotate( + num_authors=Count('book__authors'), + avg_authors=Avg('book__authors'), + max_authors=Max('book__authors'), + max_price=Max('book__price'), + max_rating=Max('book__rating'), + ).values() + ) + self.assertSequenceEqual( + qs, + [{ + 'max_authors': None, + 'name': "Jonno's House of Books", + 'num_awards': 0, + 'max_price': None, + 'num_authors': 0, + 'max_rating': None, + 'id': self.p5.id, + 'avg_authors': None, + }], + ) + + def test_more_more(self): + # Regression for #10113 - Fields mentioned in order_by() must be + # included in the GROUP BY. This only becomes a problem when the + # order_by introduces a new join. 
+ self.assertQuerysetEqual( + Book.objects.annotate(num_authors=Count('authors')).order_by('publisher__name', 'name'), [ + "Practical Django Projects", + "The Definitive Guide to Django: Web Development Done Right", + "Paradigms of Artificial Intelligence Programming: Case Studies in Common Lisp", + "Artificial Intelligence: A Modern Approach", + "Python Web Development with Django", + "Sams Teach Yourself Django in 24 Hours", + ], + lambda b: b.name + ) + + # Regression for #10127 - Empty select_related() works with annotate + qs = Book.objects.filter(rating__lt=4.5).select_related().annotate(Avg('authors__age')) + self.assertQuerysetEqual( + qs, + [ + ('Artificial Intelligence: A Modern Approach', 51.5, 'Prentice Hall', 'Peter Norvig'), + ('Practical Django Projects', 29.0, 'Apress', 'James Bennett'), + ( + 'Python Web Development with Django', + Approximate(30.333, places=2), + 'Prentice Hall', + 'Jeffrey Forcier', + ), + ('Sams Teach Yourself Django in 24 Hours', 45.0, 'Sams', 'Brad Dayley') + ], + lambda b: (b.name, b.authors__age__avg, b.publisher.name, b.contact.name) + ) + + # Regression for #10132 - If the values() clause only mentioned extra + # (select=) columns, those columns are used for grouping + qs = Book.objects.extra(select={'pub': 'publisher_id'}).values('pub').annotate(Count('id')).order_by('pub') + self.assertSequenceEqual( + qs, [ + {'pub': self.b1.id, 'id__count': 2}, + {'pub': self.b2.id, 'id__count': 1}, + {'pub': self.b3.id, 'id__count': 2}, + {'pub': self.b4.id, 'id__count': 1} + ], + ) + + qs = ( + Book.objects + .extra(select={'pub': 'publisher_id', 'foo': 'pages'}) + .values('pub') + .annotate(Count('id')) + .order_by('pub') + ) + self.assertSequenceEqual( + qs, [ + {'pub': self.p1.id, 'id__count': 2}, + {'pub': self.p2.id, 'id__count': 1}, + {'pub': self.p3.id, 'id__count': 2}, + {'pub': self.p4.id, 'id__count': 1} + ], + ) + + # Regression for #10182 - Queries with aggregate calls are correctly + # realiased when used in a subquery + 
ids = ( + Book.objects + .filter(pages__gt=100) + .annotate(n_authors=Count('authors')) + .filter(n_authors__gt=2) + .order_by('n_authors') + ) + self.assertQuerysetEqual( + Book.objects.filter(id__in=ids), [ + "Python Web Development with Django", + ], + lambda b: b.name + ) + + # Regression for #15709 - Ensure each group_by field only exists once + # per query + qstr = str(Book.objects.values('publisher').annotate(max_pages=Max('pages')).order_by().query) + # There is just one GROUP BY clause (zero commas means at most one clause). + self.assertEqual(qstr[qstr.index('GROUP BY'):].count(', '), 0) + + def test_duplicate_alias(self): + # Regression for #11256 - duplicating a default alias raises ValueError. + msg = ( + "The named annotation 'authors__age__avg' conflicts with " + "the default name for another annotation." + ) + with self.assertRaisesMessage(ValueError, msg): + Book.objects.all().annotate(Avg('authors__age'), authors__age__avg=Avg('authors__age')) + + def test_field_name_conflict(self): + # Regression for #11256 - providing an aggregate name + # that conflicts with a field name on the model raises ValueError + msg = "The annotation 'age' conflicts with a field on the model." + with self.assertRaisesMessage(ValueError, msg): + Author.objects.annotate(age=Avg('friends__age')) + + def test_m2m_name_conflict(self): + # Regression for #11256 - providing an aggregate name + # that conflicts with an m2m name on the model raises ValueError + msg = "The annotation 'friends' conflicts with a field on the model." + with self.assertRaisesMessage(ValueError, msg): + Author.objects.annotate(friends=Count('friends')) + + def test_fk_attname_conflict(self): + msg = "The annotation 'contact_id' conflicts with a field on the model." 
+ with self.assertRaisesMessage(ValueError, msg): + Book.objects.annotate(contact_id=F('publisher_id')) + + def test_values_queryset_non_conflict(self): + # Regression for #14707 -- If you're using a values query set, some potential conflicts are avoided. + + # age is a field on Author, so it shouldn't be allowed as an aggregate. + # But age isn't included in values(), so it is. + results = Author.objects.values('name').annotate(age=Count('book_contact_set')).order_by('name') + self.assertEqual(len(results), 9) + self.assertEqual(results[0]['name'], 'Adrian Holovaty') + self.assertEqual(results[0]['age'], 1) + + # Same problem, but aggregating over m2m fields + results = Author.objects.values('name').annotate(age=Avg('friends__age')).order_by('name') + self.assertEqual(len(results), 9) + self.assertEqual(results[0]['name'], 'Adrian Holovaty') + self.assertEqual(results[0]['age'], 32.0) + + # Same problem, but colliding with an m2m field + results = Author.objects.values('name').annotate(friends=Count('friends')).order_by('name') + self.assertEqual(len(results), 9) + self.assertEqual(results[0]['name'], 'Adrian Holovaty') + self.assertEqual(results[0]['friends'], 2) + + def test_reverse_relation_name_conflict(self): + # Regression for #11256 - providing an aggregate name + # that conflicts with a reverse-related name on the model raises ValueError + msg = "The annotation 'book_contact_set' conflicts with a field on the model." + with self.assertRaisesMessage(ValueError, msg): + Author.objects.annotate(book_contact_set=Avg('friends__age')) + + def test_pickle(self): + # Regression for #10197 -- Queries with aggregates can be pickled. + # First check that pickling is possible at all. No crash = success + qs = Book.objects.annotate(num_authors=Count('authors')) + pickle.dumps(qs) + + # Then check that the round trip works. 
+ query = qs.query.get_compiler(qs.db).as_sql()[0] + qs2 = pickle.loads(pickle.dumps(qs)) + self.assertEqual( + qs2.query.get_compiler(qs2.db).as_sql()[0], + query, + ) + + def test_more_more_more(self): + # Regression for #10199 - Aggregate calls clone the original query so + # the original query can still be used + books = Book.objects.all() + books.aggregate(Avg("authors__age")) + self.assertQuerysetEqual( + books.all(), [ + 'Artificial Intelligence: A Modern Approach', + 'Paradigms of Artificial Intelligence Programming: Case Studies in Common Lisp', + 'Practical Django Projects', + 'Python Web Development with Django', + 'Sams Teach Yourself Django in 24 Hours', + 'The Definitive Guide to Django: Web Development Done Right' + ], + lambda b: b.name + ) + + # Regression for #10248 - Annotations work with dates() + qs = Book.objects.annotate(num_authors=Count('authors')).filter(num_authors=2).dates('pubdate', 'day') + self.assertSequenceEqual( + qs, [ + datetime.date(1995, 1, 15), + datetime.date(2007, 12, 6), + ], + ) + + # Regression for #10290 - extra selects with parameters can be used for + # grouping. + qs = ( + Book.objects + .annotate(mean_auth_age=Avg('authors__age')) + .extra(select={'sheets': '(pages + %s) / %s'}, select_params=[1, 2]) + .order_by('sheets') + .values('sheets') + ) + self.assertQuerysetEqual( + qs, [ + 150, + 175, + 224, + 264, + 473, + 566 + ], + lambda b: int(b["sheets"]) + ) + + # Regression for 10425 - annotations don't get in the way of a count() + # clause + self.assertEqual( + Book.objects.values('publisher').annotate(Count('publisher')).count(), + 4 + ) + self.assertEqual( + Book.objects.annotate(Count('publisher')).values('publisher').count(), + 6 + ) + + # Note: intentionally no order_by(), that case needs tests, too. 
+ publishers = Publisher.objects.filter(id__in=[1, 2]) + self.assertEqual( + sorted(p.name for p in publishers), + [ + "Apress", + "Sams" + ] + ) + + publishers = publishers.annotate(n_books=Count("book")) + sorted_publishers = sorted(publishers, key=lambda x: x.name) + self.assertEqual( + sorted_publishers[0].n_books, + 2 + ) + self.assertEqual( + sorted_publishers[1].n_books, + 1 + ) + + self.assertEqual( + sorted(p.name for p in publishers), + [ + "Apress", + "Sams" + ] + ) + + books = Book.objects.filter(publisher__in=publishers) + self.assertQuerysetEqual( + books, [ + "Practical Django Projects", + "Sams Teach Yourself Django in 24 Hours", + "The Definitive Guide to Django: Web Development Done Right", + ], + lambda b: b.name + ) + self.assertEqual( + sorted(p.name for p in publishers), + [ + "Apress", + "Sams" + ] + ) + + # Regression for 10666 - inherited fields work with annotations and + # aggregations + self.assertEqual( + HardbackBook.objects.aggregate(n_pages=Sum('book_ptr__pages')), + {'n_pages': 2078} + ) + + self.assertEqual( + HardbackBook.objects.aggregate(n_pages=Sum('pages')), + {'n_pages': 2078}, + ) + + qs = HardbackBook.objects.annotate(n_authors=Count('book_ptr__authors')).values('name', 'n_authors') + self.assertSequenceEqual( + qs, + [ + {'n_authors': 2, 'name': 'Artificial Intelligence: A Modern Approach'}, + { + 'n_authors': 1, + 'name': 'Paradigms of Artificial Intelligence Programming: Case Studies in Common Lisp' + } + ], + ) + + qs = HardbackBook.objects.annotate(n_authors=Count('authors')).values('name', 'n_authors') + self.assertSequenceEqual( + qs, + [ + {'n_authors': 2, 'name': 'Artificial Intelligence: A Modern Approach'}, + { + 'n_authors': 1, + 'name': 'Paradigms of Artificial Intelligence Programming: Case Studies in Common Lisp' + } + ], + ) + + # Regression for #10766 - Shouldn't be able to reference an aggregate + # fields in an aggregate() call. 
+ msg = "Cannot compute Avg('mean_age'): 'mean_age' is an aggregate" + with self.assertRaisesMessage(FieldError, msg): + Book.objects.annotate(mean_age=Avg('authors__age')).annotate(Avg('mean_age')) + + def test_empty_filter_count(self): + self.assertEqual( + Author.objects.filter(id__in=[]).annotate(Count("friends")).count(), + 0 + ) + + def test_empty_filter_aggregate(self): + self.assertEqual( + Author.objects.filter(id__in=[]).annotate(Count("friends")).aggregate(Count("pk")), + {"pk__count": None} + ) + + def test_none_call_before_aggregate(self): + # Regression for #11789 + self.assertEqual( + Author.objects.none().aggregate(Avg('age')), + {'age__avg': None} + ) + + def test_annotate_and_join(self): + self.assertEqual( + Author.objects.annotate(c=Count("friends__name")).exclude(friends__name="Joe").count(), + Author.objects.count() + ) + + def test_f_expression_annotation(self): + # Books with less than 200 pages per author. + qs = Book.objects.values("name").annotate( + n_authors=Count("authors") + ).filter( + pages__lt=F("n_authors") * 200 + ).values_list("pk") + self.assertQuerysetEqual( + Book.objects.filter(pk__in=qs), [ + "Python Web Development with Django" + ], + attrgetter("name") + ) + + def test_values_annotate_values(self): + qs = Book.objects.values("name").annotate( + n_authors=Count("authors") + ).values_list("pk", flat=True) + self.assertEqual(list(qs), list(Book.objects.values_list("pk", flat=True))) + + def test_having_group_by(self): + # When a field occurs on the LHS of a HAVING clause that it + # appears correctly in the GROUP BY clause + qs = Book.objects.values_list("name").annotate( + n_authors=Count("authors") + ).filter( + pages__gt=F("n_authors") + ).values_list("name", flat=True) + # Results should be the same, all Books have more pages than authors + self.assertEqual( + list(qs), list(Book.objects.values_list("name", flat=True)) + ) + + def test_values_list_annotation_args_ordering(self): + """ + Annotate *args ordering should be 
preserved in values_list results. + **kwargs comes after *args. + Regression test for #23659. + """ + books = Book.objects.values_list("publisher__name").annotate( + Count("id"), Avg("price"), Avg("authors__age"), avg_pgs=Avg("pages") + ).order_by("-publisher__name") + self.assertEqual(books[0], ('Sams', 1, 23.09, 45.0, 528.0)) + + def test_annotation_disjunction(self): + qs = Book.objects.annotate(n_authors=Count("authors")).filter( + Q(n_authors=2) | Q(name="Python Web Development with Django") + ) + self.assertQuerysetEqual( + qs, [ + "Artificial Intelligence: A Modern Approach", + "Python Web Development with Django", + "The Definitive Guide to Django: Web Development Done Right", + ], + attrgetter("name") + ) + + qs = ( + Book.objects + .annotate(n_authors=Count("authors")) + .filter( + Q(name="The Definitive Guide to Django: Web Development Done Right") | + (Q(name="Artificial Intelligence: A Modern Approach") & Q(n_authors=3)) + ) + ) + self.assertQuerysetEqual( + qs, + [ + "The Definitive Guide to Django: Web Development Done Right", + ], + attrgetter("name") + ) + + qs = Publisher.objects.annotate( + rating_sum=Sum("book__rating"), + book_count=Count("book") + ).filter( + Q(rating_sum__gt=5.5) | Q(rating_sum__isnull=True) + ).order_by('pk') + self.assertQuerysetEqual( + qs, [ + "Apress", + "Prentice Hall", + "Jonno's House of Books", + ], + attrgetter("name") + ) + + qs = Publisher.objects.annotate( + rating_sum=Sum("book__rating"), + book_count=Count("book") + ).filter( + Q(rating_sum__gt=F("book_count")) | Q(rating_sum=None) + ).order_by("num_awards") + self.assertQuerysetEqual( + qs, [ + "Jonno's House of Books", + "Sams", + "Apress", + "Prentice Hall", + "Morgan Kaufmann" + ], + attrgetter("name") + ) + + def test_quoting_aggregate_order_by(self): + qs = Book.objects.filter( + name="Python Web Development with Django" + ).annotate( + authorCount=Count("authors") + ).order_by("authorCount") + self.assertQuerysetEqual( + qs, [ + ("Python Web Development 
with Django", 3), + ], + lambda b: (b.name, b.authorCount) + ) + + @skipUnlessDBFeature('supports_stddev') + def test_stddev(self): + self.assertEqual( + Book.objects.aggregate(StdDev('pages')), + {'pages__stddev': Approximate(311.46, 1)} + ) + + self.assertEqual( + Book.objects.aggregate(StdDev('rating')), + {'rating__stddev': Approximate(0.60, 1)} + ) + + self.assertEqual( + Book.objects.aggregate(StdDev('price')), + {'price__stddev': Approximate(24.16, 2)} + ) + + self.assertEqual( + Book.objects.aggregate(StdDev('pages', sample=True)), + {'pages__stddev': Approximate(341.19, 2)} + ) + + self.assertEqual( + Book.objects.aggregate(StdDev('rating', sample=True)), + {'rating__stddev': Approximate(0.66, 2)} + ) + + self.assertEqual( + Book.objects.aggregate(StdDev('price', sample=True)), + {'price__stddev': Approximate(26.46, 1)} + ) + + self.assertEqual( + Book.objects.aggregate(Variance('pages')), + {'pages__variance': Approximate(97010.80, 1)} + ) + + self.assertEqual( + Book.objects.aggregate(Variance('rating')), + {'rating__variance': Approximate(0.36, 1)} + ) + + self.assertEqual( + Book.objects.aggregate(Variance('price')), + {'price__variance': Approximate(583.77, 1)} + ) + + self.assertEqual( + Book.objects.aggregate(Variance('pages', sample=True)), + {'pages__variance': Approximate(116412.96, 1)} + ) + + self.assertEqual( + Book.objects.aggregate(Variance('rating', sample=True)), + {'rating__variance': Approximate(0.44, 2)} + ) + + self.assertEqual( + Book.objects.aggregate(Variance('price', sample=True)), + {'price__variance': Approximate(700.53, 2)} + ) + + def test_filtering_by_annotation_name(self): + # Regression test for #14476 + + # The name of the explicitly provided annotation name in this case + # poses no problem + qs = Author.objects.annotate(book_cnt=Count('book')).filter(book_cnt=2).order_by('name') + self.assertQuerysetEqual( + qs, + ['Peter Norvig'], + lambda b: b.name + ) + # Neither in this case + qs = 
Author.objects.annotate(book_count=Count('book')).filter(book_count=2).order_by('name') + self.assertQuerysetEqual( + qs, + ['Peter Norvig'], + lambda b: b.name + ) + # This case used to fail because the ORM couldn't resolve the + # automatically generated annotation name `book__count` + qs = Author.objects.annotate(Count('book')).filter(book__count=2).order_by('name') + self.assertQuerysetEqual( + qs, + ['Peter Norvig'], + lambda b: b.name + ) + # Referencing the auto-generated name in an aggregate() also works. + self.assertEqual( + Author.objects.annotate(Count('book')).aggregate(Max('book__count')), + {'book__count__max': 2} + ) + + def test_annotate_joins(self): + """ + The base table's join isn't promoted to LOUTER. This could + cause the query generation to fail if there is an exclude() for fk-field + in the query, too. Refs #19087. + """ + qs = Book.objects.annotate(n=Count('pk')) + self.assertIs(qs.query.alias_map['aggregation_regress_book'].join_type, None) + # The query executes without problems. + self.assertEqual(len(qs.exclude(publisher=-1)), 6) + + @skipUnlessAnyDBFeature('allows_group_by_pk', 'allows_group_by_selected_pks') + def test_aggregate_duplicate_columns(self): + # Regression test for #17144 + + results = Author.objects.annotate(num_contacts=Count('book_contact_set')) + + # There should only be one GROUP BY clause, for the `id` column. + # `name` and `age` should not be grouped on. + _, _, group_by = results.query.get_compiler(using='default').pre_sql_setup() + self.assertEqual(len(group_by), 1) + self.assertIn('id', group_by[0][0]) + self.assertNotIn('name', group_by[0][0]) + self.assertNotIn('age', group_by[0][0]) + self.assertEqual( + [(a.name, a.num_contacts) for a in results.order_by('name')], + [ + ('Adrian Holovaty', 1), + ('Brad Dayley', 1), + ('Jacob Kaplan-Moss', 0), + ('James Bennett', 1), + ('Jeffrey Forcier', 1), + ('Paul Bissex', 0), + ('Peter Norvig', 2), + ('Stuart Russell', 0), + ('Wesley J. 
Chun', 0), + ] + ) + + @skipUnlessAnyDBFeature('allows_group_by_pk', 'allows_group_by_selected_pks') + def test_aggregate_duplicate_columns_only(self): + # Works with only() too. + results = Author.objects.only('id', 'name').annotate(num_contacts=Count('book_contact_set')) + _, _, grouping = results.query.get_compiler(using='default').pre_sql_setup() + self.assertEqual(len(grouping), 1) + self.assertIn('id', grouping[0][0]) + self.assertNotIn('name', grouping[0][0]) + self.assertNotIn('age', grouping[0][0]) + self.assertEqual( + [(a.name, a.num_contacts) for a in results.order_by('name')], + [ + ('Adrian Holovaty', 1), + ('Brad Dayley', 1), + ('Jacob Kaplan-Moss', 0), + ('James Bennett', 1), + ('Jeffrey Forcier', 1), + ('Paul Bissex', 0), + ('Peter Norvig', 2), + ('Stuart Russell', 0), + ('Wesley J. Chun', 0), + ] + ) + + @skipUnlessAnyDBFeature('allows_group_by_pk', 'allows_group_by_selected_pks') + def test_aggregate_duplicate_columns_select_related(self): + # And select_related() + results = Book.objects.select_related('contact').annotate( + num_authors=Count('authors')) + _, _, grouping = results.query.get_compiler(using='default').pre_sql_setup() + # In the case of `group_by_selected_pks` we also group by contact.id because of the select_related. 
+ self.assertEqual(len(grouping), 1 if connection.features.allows_group_by_pk else 2) + self.assertIn('id', grouping[0][0]) + self.assertNotIn('name', grouping[0][0]) + self.assertNotIn('contact', grouping[0][0]) + self.assertEqual( + [(b.name, b.num_authors) for b in results.order_by('name')], + [ + ('Artificial Intelligence: A Modern Approach', 2), + ('Paradigms of Artificial Intelligence Programming: Case Studies in Common Lisp', 1), + ('Practical Django Projects', 1), + ('Python Web Development with Django', 3), + ('Sams Teach Yourself Django in 24 Hours', 1), + ('The Definitive Guide to Django: Web Development Done Right', 2) + ] + ) + + @skipUnlessDBFeature('allows_group_by_selected_pks') + def test_aggregate_ummanaged_model_columns(self): + """ + Unmanaged models are sometimes used to represent database views which + may not allow grouping by selected primary key. + """ + def assertQuerysetResults(queryset): + self.assertEqual( + [(b.name, b.num_authors) for b in queryset.order_by('name')], + [ + ('Artificial Intelligence: A Modern Approach', 2), + ('Paradigms of Artificial Intelligence Programming: Case Studies in Common Lisp', 1), + ('Practical Django Projects', 1), + ('Python Web Development with Django', 3), + ('Sams Teach Yourself Django in 24 Hours', 1), + ('The Definitive Guide to Django: Web Development Done Right', 2), + ] + ) + queryset = Book.objects.select_related('contact').annotate(num_authors=Count('authors')) + # Unmanaged origin model. + with mock.patch.object(Book._meta, 'managed', False): + _, _, grouping = queryset.query.get_compiler(using='default').pre_sql_setup() + self.assertEqual(len(grouping), len(Book._meta.fields) + 1) + for index, field in enumerate(Book._meta.fields): + self.assertIn(field.name, grouping[index][0]) + self.assertIn(Author._meta.pk.name, grouping[-1][0]) + assertQuerysetResults(queryset) + # Unmanaged related model. 
+ with mock.patch.object(Author._meta, 'managed', False): + _, _, grouping = queryset.query.get_compiler(using='default').pre_sql_setup() + self.assertEqual(len(grouping), len(Author._meta.fields) + 1) + self.assertIn(Book._meta.pk.name, grouping[0][0]) + for index, field in enumerate(Author._meta.fields): + self.assertIn(field.name, grouping[index + 1][0]) + assertQuerysetResults(queryset) + + def test_reverse_join_trimming(self): + qs = Author.objects.annotate(Count('book_contact_set__contact')) + self.assertIn(' JOIN ', str(qs.query)) + + def test_aggregation_with_generic_reverse_relation(self): + """ + Regression test for #10870: Aggregates with joins ignore extra + filters provided by setup_joins + + tests aggregations with generic reverse relations + """ + django_book = Book.objects.get(name='Practical Django Projects') + ItemTag.objects.create( + object_id=django_book.id, tag='intermediate', + content_type=ContentType.objects.get_for_model(django_book), + ) + ItemTag.objects.create( + object_id=django_book.id, tag='django', + content_type=ContentType.objects.get_for_model(django_book), + ) + # Assign a tag to model with same PK as the book above. If the JOIN + # used in aggregation doesn't have content type as part of the + # condition the annotation will also count the 'hi mom' tag for b. 
+ wmpk = WithManualPK.objects.create(id=django_book.pk) + ItemTag.objects.create( + object_id=wmpk.id, tag='hi mom', + content_type=ContentType.objects.get_for_model(wmpk), + ) + ai_book = Book.objects.get(name__startswith='Paradigms of Artificial Intelligence') + ItemTag.objects.create( + object_id=ai_book.id, tag='intermediate', + content_type=ContentType.objects.get_for_model(ai_book), + ) + + self.assertEqual(Book.objects.aggregate(Count('tags')), {'tags__count': 3}) + results = Book.objects.annotate(Count('tags')).order_by('-tags__count', 'name') + self.assertEqual( + [(b.name, b.tags__count) for b in results], + [ + ('Practical Django Projects', 2), + ('Paradigms of Artificial Intelligence Programming: Case Studies in Common Lisp', 1), + ('Artificial Intelligence: A Modern Approach', 0), + ('Python Web Development with Django', 0), + ('Sams Teach Yourself Django in 24 Hours', 0), + ('The Definitive Guide to Django: Web Development Done Right', 0) + ] + ) + + def test_negated_aggregation(self): + expected_results = Author.objects.exclude( + pk__in=Author.objects.annotate(book_cnt=Count('book')).filter(book_cnt=2) + ).order_by('name') + expected_results = [a.name for a in expected_results] + qs = Author.objects.annotate(book_cnt=Count('book')).exclude( + Q(book_cnt=2), Q(book_cnt=2)).order_by('name') + self.assertQuerysetEqual( + qs, + expected_results, + lambda b: b.name + ) + expected_results = Author.objects.exclude( + pk__in=Author.objects.annotate(book_cnt=Count('book')).filter(book_cnt=2) + ).order_by('name') + expected_results = [a.name for a in expected_results] + qs = Author.objects.annotate(book_cnt=Count('book')).exclude(Q(book_cnt=2) | Q(book_cnt=2)).order_by('name') + self.assertQuerysetEqual( + qs, + expected_results, + lambda b: b.name + ) + + def test_name_filters(self): + qs = Author.objects.annotate(Count('book')).filter( + Q(book__count__exact=2) | Q(name='Adrian Holovaty') + ).order_by('name') + self.assertQuerysetEqual( + qs, + ['Adrian 
Holovaty', 'Peter Norvig'], + lambda b: b.name + ) + + def test_name_expressions(self): + # Aggregates are spotted correctly from F objects. + # Note that Adrian's age is 34 in the fixtures, and he has one book + # so both conditions match one author. + qs = Author.objects.annotate(Count('book')).filter( + Q(name='Peter Norvig') | Q(age=F('book__count') + 33) + ).order_by('name') + self.assertQuerysetEqual( + qs, + ['Adrian Holovaty', 'Peter Norvig'], + lambda b: b.name + ) + + def test_ticket_11293(self): + q1 = Q(price__gt=50) + q2 = Q(authors__count__gt=1) + query = Book.objects.annotate(Count('authors')).filter( + q1 | q2).order_by('pk') + self.assertQuerysetEqual( + query, [1, 4, 5, 6], + lambda b: b.pk) + + def test_ticket_11293_q_immutable(self): + """ + Splitting a q object to parts for where/having doesn't alter + the original q-object. + """ + q1 = Q(isbn='') + q2 = Q(authors__count__gt=1) + query = Book.objects.annotate(Count('authors')) + query.filter(q1 | q2) + self.assertEqual(len(q2.children), 1) + + def test_fobj_group_by(self): + """ + An F() object referring to related column works correctly in group by. + """ + qs = Book.objects.annotate( + account=Count('authors') + ).filter( + account=F('publisher__num_awards') + ) + self.assertQuerysetEqual( + qs, ['Sams Teach Yourself Django in 24 Hours'], + lambda b: b.name) + + def test_annotate_reserved_word(self): + """ + Regression #18333 - Ensure annotated column name is properly quoted. 
+ """ + vals = Book.objects.annotate(select=Count('authors__id')).aggregate(Sum('select'), Avg('select')) + self.assertEqual(vals, { + 'select__sum': 10, + 'select__avg': Approximate(1.666, places=2), + }) + + def test_annotate_on_relation(self): + book = Book.objects.annotate(avg_price=Avg('price'), publisher_name=F('publisher__name')).get(pk=self.b1.pk) + self.assertEqual(book.avg_price, 30.00) + self.assertEqual(book.publisher_name, "Apress") + + def test_aggregate_on_relation(self): + # A query with an existing annotation aggregation on a relation should + # succeed. + qs = Book.objects.annotate(avg_price=Avg('price')).aggregate( + publisher_awards=Sum('publisher__num_awards') + ) + self.assertEqual(qs['publisher_awards'], 30) + + def test_annotate_distinct_aggregate(self): + # There are three books with rating of 4.0 and two of the books have + # the same price. Hence, the distinct removes one rating of 4.0 + # from the results. + vals1 = Book.objects.values('rating', 'price').distinct().aggregate(result=Sum('rating')) + vals2 = Book.objects.aggregate(result=Sum('rating') - Value(4.0)) + self.assertEqual(vals1, vals2) + + def test_annotate_values_list_flat(self): + """Find ages that are shared by at least two authors.""" + qs = Author.objects.values_list('age', flat=True).annotate(age_count=Count('age')).filter(age_count__gt=1) + self.assertSequenceEqual(qs, [29]) + + +class JoinPromotionTests(TestCase): + def test_ticket_21150(self): + b = Bravo.objects.create() + c = Charlie.objects.create(bravo=b) + qs = Charlie.objects.select_related('alfa').annotate(Count('bravo__charlie')) + self.assertSequenceEqual(qs, [c]) + self.assertIs(qs[0].alfa, None) + a = Alfa.objects.create() + c.alfa = a + c.save() + # Force re-evaluation + qs = qs.all() + self.assertSequenceEqual(qs, [c]) + self.assertEqual(qs[0].alfa, a) + + def test_existing_join_not_promoted(self): + # No promotion for existing joins + qs = 
Charlie.objects.filter(alfa__name__isnull=False).annotate(Count('alfa__name')) + self.assertIn(' INNER JOIN ', str(qs.query)) + # Also, the existing join is unpromoted when doing filtering for already + # promoted join. + qs = Charlie.objects.annotate(Count('alfa__name')).filter(alfa__name__isnull=False) + self.assertIn(' INNER JOIN ', str(qs.query)) + # But, as the join is nullable first use by annotate will be LOUTER + qs = Charlie.objects.annotate(Count('alfa__name')) + self.assertIn(' LEFT OUTER JOIN ', str(qs.query)) + + def test_non_nullable_fk_not_promoted(self): + qs = Book.objects.annotate(Count('contact__name')) + self.assertIn(' INNER JOIN ', str(qs.query)) + + +class SelfReferentialFKTests(TestCase): + def test_ticket_24748(self): + t1 = SelfRefFK.objects.create(name='t1') + SelfRefFK.objects.create(name='t2', parent=t1) + SelfRefFK.objects.create(name='t3', parent=t1) + self.assertQuerysetEqual( + SelfRefFK.objects.annotate(num_children=Count('children')).order_by('name'), + [('t1', 2), ('t2', 0), ('t3', 0)], + lambda x: (x.name, x.num_children) + ) diff --git a/tests/bulk_create/__init__.py b/tests/bulk_create/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/bulk_create/models.py b/tests/bulk_create/models.py new file mode 100644 index 00000000..c302a70b --- /dev/null +++ b/tests/bulk_create/models.py @@ -0,0 +1,53 @@ +from django.db import models + + +class Country(models.Model): + name = models.CharField(max_length=255) + iso_two_letter = models.CharField(max_length=2) + + +class ProxyCountry(Country): + class Meta: + proxy = True + + +class ProxyProxyCountry(ProxyCountry): + class Meta: + proxy = True + + +class ProxyMultiCountry(ProxyCountry): + pass + + +class ProxyMultiProxyCountry(ProxyMultiCountry): + class Meta: + proxy = True + + +class Place(models.Model): + name = models.CharField(max_length=100) + + class Meta: + abstract = True + + +class Restaurant(Place): + pass + + +class Pizzeria(Restaurant): + pass + + 
+class State(models.Model): + two_letter_code = models.CharField(max_length=2, primary_key=True) + + +class TwoFields(models.Model): + f1 = models.IntegerField(unique=True) + f2 = models.IntegerField(unique=True) + + +class NoFields(models.Model): + pass diff --git a/tests/bulk_create/tests.py b/tests/bulk_create/tests.py new file mode 100644 index 00000000..210be56c --- /dev/null +++ b/tests/bulk_create/tests.py @@ -0,0 +1,235 @@ +from __future__ import unicode_literals + +from operator import attrgetter + +from django.db import connection +from django.db.models import Value +from django.db.models.functions import Lower +from django.test import ( + TestCase, override_settings, skipIfDBFeature, skipUnlessDBFeature, +) + +from .models import ( + Country, NoFields, Pizzeria, ProxyCountry, ProxyMultiCountry, + ProxyMultiProxyCountry, ProxyProxyCountry, Restaurant, State, TwoFields, +) + + +class BulkCreateTests(TestCase): + def setUp(self): + self.data = [ + Country(name="United States of America", iso_two_letter="US"), + Country(name="The Netherlands", iso_two_letter="NL"), + Country(name="Germany", iso_two_letter="DE"), + Country(name="Czech Republic", iso_two_letter="CZ") + ] + + def test_simple(self): + created = Country.objects.bulk_create(self.data) + self.assertEqual(len(created), 4) + self.assertQuerysetEqual(Country.objects.order_by("-name"), [ + "United States of America", "The Netherlands", "Germany", "Czech Republic" + ], attrgetter("name")) + + created = Country.objects.bulk_create([]) + self.assertEqual(created, []) + self.assertEqual(Country.objects.count(), 4) + + @skipUnlessDBFeature('has_bulk_insert') + def test_efficiency(self): + with self.assertNumQueries(1): + Country.objects.bulk_create(self.data) + + def test_multi_table_inheritance_unsupported(self): + expected_message = "Can't bulk create a multi-table inherited model" + with self.assertRaisesMessage(ValueError, expected_message): + Pizzeria.objects.bulk_create([ + Pizzeria(name="The Art of 
Pizza"), + ]) + with self.assertRaisesMessage(ValueError, expected_message): + ProxyMultiCountry.objects.bulk_create([ + ProxyMultiCountry(name="Fillory", iso_two_letter="FL"), + ]) + with self.assertRaisesMessage(ValueError, expected_message): + ProxyMultiProxyCountry.objects.bulk_create([ + ProxyMultiProxyCountry(name="Fillory", iso_two_letter="FL"), + ]) + + def test_proxy_inheritance_supported(self): + ProxyCountry.objects.bulk_create([ + ProxyCountry(name="Qwghlm", iso_two_letter="QW"), + Country(name="Tortall", iso_two_letter="TA"), + ]) + self.assertQuerysetEqual(ProxyCountry.objects.all(), { + "Qwghlm", "Tortall" + }, attrgetter("name"), ordered=False) + + ProxyProxyCountry.objects.bulk_create([ + ProxyProxyCountry(name="Netherlands", iso_two_letter="NT"), + ]) + self.assertQuerysetEqual(ProxyProxyCountry.objects.all(), { + "Qwghlm", "Tortall", "Netherlands", + }, attrgetter("name"), ordered=False) + + def test_non_auto_increment_pk(self): + State.objects.bulk_create([ + State(two_letter_code=s) + for s in ["IL", "NY", "CA", "ME"] + ]) + self.assertQuerysetEqual(State.objects.order_by("two_letter_code"), [ + "CA", "IL", "ME", "NY", + ], attrgetter("two_letter_code")) + + @skipUnlessDBFeature('has_bulk_insert') + def test_non_auto_increment_pk_efficiency(self): + with self.assertNumQueries(1): + State.objects.bulk_create([ + State(two_letter_code=s) + for s in ["IL", "NY", "CA", "ME"] + ]) + self.assertQuerysetEqual(State.objects.order_by("two_letter_code"), [ + "CA", "IL", "ME", "NY", + ], attrgetter("two_letter_code")) + + @skipIfDBFeature('allows_auto_pk_0') + def test_zero_as_autoval(self): + """ + Zero as id for AutoField should raise exception in MySQL, because MySQL + does not allow zero for automatic primary key. 
+ """ + valid_country = Country(name='Germany', iso_two_letter='DE') + invalid_country = Country(id=0, name='Poland', iso_two_letter='PL') + with self.assertRaises(ValueError): + Country.objects.bulk_create([valid_country, invalid_country]) + + def test_batch_same_vals(self): + # Sqlite had a problem where all the same-valued models were + # collapsed to one insert. + Restaurant.objects.bulk_create([ + Restaurant(name='foo') for i in range(0, 2) + ]) + self.assertEqual(Restaurant.objects.count(), 2) + + def test_large_batch(self): + with override_settings(DEBUG=True): + connection.queries_log.clear() + TwoFields.objects.bulk_create([ + TwoFields(f1=i, f2=i + 1) for i in range(0, 1001) + ]) + self.assertEqual(TwoFields.objects.count(), 1001) + self.assertEqual( + TwoFields.objects.filter(f1__gte=450, f1__lte=550).count(), + 101) + self.assertEqual(TwoFields.objects.filter(f2__gte=901).count(), 101) + + @skipUnlessDBFeature('has_bulk_insert') + def test_large_single_field_batch(self): + # SQLite had a problem with more than 500 UNIONed selects in single + # query. + Restaurant.objects.bulk_create([ + Restaurant() for i in range(0, 501) + ]) + + @skipUnlessDBFeature('has_bulk_insert') + def test_large_batch_efficiency(self): + with override_settings(DEBUG=True): + connection.queries_log.clear() + TwoFields.objects.bulk_create([ + TwoFields(f1=i, f2=i + 1) for i in range(0, 1001) + ]) + self.assertLess(len(connection.queries), 10) + + def test_large_batch_mixed(self): + """ + Test inserting a large batch with objects having primary key set + mixed together with objects without PK set. + """ + with override_settings(DEBUG=True): + connection.queries_log.clear() + TwoFields.objects.bulk_create([ + TwoFields(id=i if i % 2 == 0 else None, f1=i, f2=i + 1) + for i in range(100000, 101000)]) + self.assertEqual(TwoFields.objects.count(), 1000) + # We can't assume much about the ID's created, except that the above + # created IDs must exist. 
+ id_range = range(100000, 101000, 2) + self.assertEqual(TwoFields.objects.filter(id__in=id_range).count(), 500) + self.assertEqual(TwoFields.objects.exclude(id__in=id_range).count(), 500) + + @skipUnlessDBFeature('has_bulk_insert') + def test_large_batch_mixed_efficiency(self): + """ + Test inserting a large batch with objects having primary key set + mixed together with objects without PK set. + """ + with override_settings(DEBUG=True): + connection.queries_log.clear() + TwoFields.objects.bulk_create([ + TwoFields(id=i if i % 2 == 0 else None, f1=i, f2=i + 1) + for i in range(100000, 101000)]) + self.assertLess(len(connection.queries), 10) + + def test_explicit_batch_size(self): + objs = [TwoFields(f1=i, f2=i) for i in range(0, 4)] + num_objs = len(objs) + TwoFields.objects.bulk_create(objs, batch_size=1) + self.assertEqual(TwoFields.objects.count(), num_objs) + TwoFields.objects.all().delete() + TwoFields.objects.bulk_create(objs, batch_size=2) + self.assertEqual(TwoFields.objects.count(), num_objs) + TwoFields.objects.all().delete() + TwoFields.objects.bulk_create(objs, batch_size=3) + self.assertEqual(TwoFields.objects.count(), num_objs) + TwoFields.objects.all().delete() + TwoFields.objects.bulk_create(objs, batch_size=num_objs) + self.assertEqual(TwoFields.objects.count(), num_objs) + + def test_empty_model(self): + self.skipTest("TODO fix ZeroDivisionError: integer division or modulo by zero") + NoFields.objects.bulk_create([NoFields() for i in range(2)]) + self.assertEqual(NoFields.objects.count(), 2) + + @skipUnlessDBFeature('has_bulk_insert') + def test_explicit_batch_size_efficiency(self): + objs = [TwoFields(f1=i, f2=i) for i in range(0, 100)] + with self.assertNumQueries(2): + TwoFields.objects.bulk_create(objs, 50) + TwoFields.objects.all().delete() + with self.assertNumQueries(1): + TwoFields.objects.bulk_create(objs, len(objs)) + + @skipUnlessDBFeature('has_bulk_insert') + def test_bulk_insert_expressions(self): + Restaurant.objects.bulk_create([ + 
Restaurant(name="Sam's Shake Shack"), + Restaurant(name=Lower(Value("Betty's Beetroot Bar"))) + ]) + bbb = Restaurant.objects.filter(name="betty's beetroot bar") + self.assertEqual(bbb.count(), 1) + + @skipUnlessDBFeature('can_return_ids_from_bulk_insert') + def test_set_pk_and_insert_single_item(self): + with self.assertNumQueries(1): + countries = Country.objects.bulk_create([self.data[0]]) + self.assertEqual(len(countries), 1) + self.assertEqual(Country.objects.get(pk=countries[0].pk), countries[0]) + + @skipUnlessDBFeature('can_return_ids_from_bulk_insert') + def test_set_pk_and_query_efficiency(self): + with self.assertNumQueries(1): + countries = Country.objects.bulk_create(self.data) + self.assertEqual(len(countries), 4) + self.assertEqual(Country.objects.get(pk=countries[0].pk), countries[0]) + self.assertEqual(Country.objects.get(pk=countries[1].pk), countries[1]) + self.assertEqual(Country.objects.get(pk=countries[2].pk), countries[2]) + self.assertEqual(Country.objects.get(pk=countries[3].pk), countries[3]) + + @skipUnlessDBFeature('can_return_ids_from_bulk_insert') + def test_set_state(self): + country_nl = Country(name='Netherlands', iso_two_letter='NL') + country_be = Country(name='Belgium', iso_two_letter='BE') + Country.objects.bulk_create([country_nl]) + country_be.save() + # Objects save via bulk_create() and save() should have equal state. + self.assertEqual(country_nl._state.adding, country_be._state.adding) + self.assertEqual(country_nl._state.db, country_be._state.db) diff --git a/tests/custom_columns/__init__.py b/tests/custom_columns/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/custom_columns/models.py b/tests/custom_columns/models.py new file mode 100644 index 00000000..3f619a7f --- /dev/null +++ b/tests/custom_columns/models.py @@ -0,0 +1,55 @@ +""" +Custom column/table names + +If your database column name is different than your model attribute, use the +``db_column`` parameter. 
Note that you'll use the field's name, not its column +name, in API usage. + +If your database table name is different than your model name, use the +``db_table`` Meta attribute. This has no effect on the API used to +query the database. + +If you need to use a table name for a many-to-many relationship that differs +from the default generated name, use the ``db_table`` parameter on the +``ManyToManyField``. This has no effect on the API for querying the database. + +""" + +from __future__ import unicode_literals + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Author(models.Model): + Author_ID = models.AutoField(primary_key=True, db_column='Author ID') + first_name = models.CharField(max_length=30, db_column='firstname') + last_name = models.CharField(max_length=30, db_column='last') + + def __str__(self): + return '%s %s' % (self.first_name, self.last_name) + + class Meta: + db_table = 'my_author_table' + ordering = ('last_name', 'first_name') + + +@python_2_unicode_compatible +class Article(models.Model): + Article_ID = models.AutoField(primary_key=True, db_column='Article ID') + headline = models.CharField(max_length=100) + authors = models.ManyToManyField(Author, db_table='my_m2m_table') + primary_author = models.ForeignKey( + Author, + models.SET_NULL, + db_column='Author ID', + related_name='primary_set', + null=True, + ) + + def __str__(self): + return self.headline + + class Meta: + ordering = ('headline',) diff --git a/tests/custom_columns/tests.py b/tests/custom_columns/tests.py new file mode 100644 index 00000000..7102e4fd --- /dev/null +++ b/tests/custom_columns/tests.py @@ -0,0 +1,123 @@ +from __future__ import unicode_literals + +from django.core.exceptions import FieldError +from django.test import TestCase +from django.utils import six + +from .models import Article, Author + + +class CustomColumnsTests(TestCase): + + def setUp(self): + self.a1 = 
Author.objects.create(first_name="John", last_name="Smith") + self.a2 = Author.objects.create(first_name="Peter", last_name="Jones") + self.authors = [self.a1, self.a2] + + self.article = Article.objects.create(headline="Django lets you build Web apps easily", primary_author=self.a1) + self.article.authors.set(self.authors) + + def test_query_all_available_authors(self): + self.assertQuerysetEqual( + Author.objects.all(), [ + "Peter Jones", "John Smith", + ], + six.text_type + ) + + def test_get_first_name(self): + self.assertEqual( + Author.objects.get(first_name__exact="John"), + self.a1, + ) + + def test_filter_first_name(self): + self.assertQuerysetEqual( + Author.objects.filter(first_name__exact="John"), [ + "John Smith", + ], + six.text_type + ) + + def test_field_error(self): + with self.assertRaises(FieldError): + Author.objects.filter(firstname__exact="John") + + def test_attribute_error(self): + with self.assertRaises(AttributeError): + self.a1.firstname + + with self.assertRaises(AttributeError): + self.a1.last + + def test_get_all_authors_for_an_article(self): + self.assertQuerysetEqual( + self.article.authors.all(), [ + "Peter Jones", + "John Smith", + ], + six.text_type + ) + + def test_get_all_articles_for_an_author(self): + self.assertQuerysetEqual( + self.a1.article_set.all(), [ + "Django lets you build Web apps easily", + ], + lambda a: a.headline + ) + + def test_get_author_m2m_relation(self): + self.assertQuerysetEqual( + self.article.authors.filter(last_name='Jones'), [ + "Peter Jones" + ], + six.text_type + ) + + def test_author_querying(self): + self.assertQuerysetEqual( + Author.objects.all().order_by('last_name'), + ['', ''] + ) + + def test_author_filtering(self): + self.assertQuerysetEqual( + Author.objects.filter(first_name__exact='John'), + [''] + ) + + def test_author_get(self): + self.assertEqual(self.a1, Author.objects.get(first_name__exact='John')) + + def test_filter_on_nonexistent_field(self): + msg = ( + "Cannot resolve keyword 
'firstname' into field. Choices are: " + "Author_ID, article, first_name, last_name, primary_set" + ) + with self.assertRaisesMessage(FieldError, msg): + Author.objects.filter(firstname__exact='John') + + def test_author_get_attributes(self): + a = Author.objects.get(last_name__exact='Smith') + self.assertEqual('John', a.first_name) + self.assertEqual('Smith', a.last_name) + with self.assertRaisesMessage(AttributeError, "'Author' object has no attribute 'firstname'"): + getattr(a, 'firstname') + + with self.assertRaisesMessage(AttributeError, "'Author' object has no attribute 'last'"): + getattr(a, 'last') + + def test_m2m_table(self): + self.assertQuerysetEqual( + self.article.authors.all().order_by('last_name'), + ['', ''] + ) + self.assertQuerysetEqual( + self.a1.article_set.all(), + [''] + ) + self.assertQuerysetEqual( + self.article.authors.filter(last_name='Jones'), + [''] + ) diff --git a/tests/custom_pk/__init__.py b/tests/custom_pk/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/custom_pk/fields.py b/tests/custom_pk/fields.py new file mode 100644 index 00000000..5bd249df --- /dev/null +++ b/tests/custom_pk/fields.py @@ -0,0 +1,60 @@ +import random +import string + +from django.db import models + + +class MyWrapper: + def __init__(self, value): + self.value = value + + def __repr__(self): + return "<%s: %s>" % (self.__class__.__name__, self.value) + + def __str__(self): + return self.value + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.value == other.value + return self.value == other + + +class MyAutoField(models.CharField): + + def __init__(self, *args, **kwargs): + kwargs['max_length'] = 10 + super().__init__(*args, **kwargs) + + def pre_save(self, instance, add): + value = getattr(instance, self.attname, None) + if not value: + value = MyWrapper(''.join(random.sample(string.ascii_lowercase, 10))) + setattr(instance, self.attname, value) + return value + + def to_python(self, value): + if 
not value: + return + if not isinstance(value, MyWrapper): + value = MyWrapper(value) + return value + + def from_db_value(self, value, expression, connection): + if not value: + return + return MyWrapper(value) + + def get_db_prep_save(self, value, connection): + if not value: + return + if isinstance(value, MyWrapper): + return str(value) + return value + + def get_db_prep_value(self, value, connection, prepared=False): + if not value: + return + if isinstance(value, MyWrapper): + return str(value) + return value diff --git a/tests/custom_pk/models.py b/tests/custom_pk/models.py new file mode 100644 index 00000000..0b272c11 --- /dev/null +++ b/tests/custom_pk/models.py @@ -0,0 +1,44 @@ +""" +Using a custom primary key + +By default, Django adds an ``"id"`` field to each model. But you can override +this behavior by explicitly adding ``primary_key=True`` to a field. +""" + +from django.db import models + +from .fields import MyAutoField + + +class Employee(models.Model): + employee_code = models.IntegerField(primary_key=True, db_column='code') + first_name = models.CharField(max_length=20) + last_name = models.CharField(max_length=20) + + class Meta: + ordering = ('last_name', 'first_name') + + def __str__(self): + return "%s %s" % (self.first_name, self.last_name) + + +class Business(models.Model): + name = models.CharField(max_length=20, primary_key=True) + employees = models.ManyToManyField(Employee) + + class Meta: + verbose_name_plural = 'businesses' + + def __str__(self): + return self.name + + +class Bar(models.Model): + id = MyAutoField(primary_key=True, db_index=True) + + def __str__(self): + return repr(self.pk) + + +class Foo(models.Model): + bar = models.ForeignKey(Bar, models.CASCADE) diff --git a/tests/custom_pk/tests.py b/tests/custom_pk/tests.py new file mode 100644 index 00000000..da0cff14 --- /dev/null +++ b/tests/custom_pk/tests.py @@ -0,0 +1,232 @@ +from django.db import IntegrityError, transaction +from django.test import TestCase, 
skipIfDBFeature + +from .models import Bar, Business, Employee, Foo + + +class BasicCustomPKTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.dan = Employee.objects.create( + employee_code=123, first_name="Dan", last_name="Jones", + ) + cls.fran = Employee.objects.create( + employee_code=456, first_name="Fran", last_name="Bones", + ) + cls.business = Business.objects.create(name="Sears") + cls.business.employees.add(cls.dan, cls.fran) + + def test_querysets(self): + """ + Both pk and custom attribute_name can be used in filter and friends + """ + self.assertQuerysetEqual( + Employee.objects.filter(pk=123), [ + "Dan Jones", + ], + str + ) + + self.assertQuerysetEqual( + Employee.objects.filter(employee_code=123), [ + "Dan Jones", + ], + str + ) + + self.assertQuerysetEqual( + Employee.objects.filter(pk__in=[123, 456]), [ + "Fran Bones", + "Dan Jones", + ], + str + ) + + self.assertQuerysetEqual( + Employee.objects.all(), [ + "Fran Bones", + "Dan Jones", + ], + str + ) + + self.assertQuerysetEqual( + Business.objects.filter(name="Sears"), [ + "Sears" + ], + lambda b: b.name + ) + self.assertQuerysetEqual( + Business.objects.filter(pk="Sears"), [ + "Sears", + ], + lambda b: b.name + ) + + def test_querysets_related_name(self): + """ + Custom pk doesn't affect related_name based lookups + """ + self.assertQuerysetEqual( + self.business.employees.all(), [ + "Fran Bones", + "Dan Jones", + ], + str + ) + self.assertQuerysetEqual( + self.fran.business_set.all(), [ + "Sears", + ], + lambda b: b.name + ) + + def test_querysets_relational(self): + """ + Queries across tables, involving primary key + """ + self.assertQuerysetEqual( + Employee.objects.filter(business__name="Sears"), [ + "Fran Bones", + "Dan Jones", + ], + str, + ) + self.assertQuerysetEqual( + Employee.objects.filter(business__pk="Sears"), [ + "Fran Bones", + "Dan Jones", + ], + str, + ) + + self.assertQuerysetEqual( + Business.objects.filter(employees__employee_code=123), [ + "Sears", + ], + 
lambda b: b.name + ) + self.assertQuerysetEqual( + Business.objects.filter(employees__pk=123), [ + "Sears", + ], + lambda b: b.name, + ) + + self.assertQuerysetEqual( + Business.objects.filter(employees__first_name__startswith="Fran"), [ + "Sears", + ], + lambda b: b.name + ) + + def test_get(self): + """ + Get can accept pk or the real attribute name + """ + self.assertEqual(Employee.objects.get(pk=123), self.dan) + self.assertEqual(Employee.objects.get(pk=456), self.fran) + + with self.assertRaises(Employee.DoesNotExist): + Employee.objects.get(pk=42) + + # Use the name of the primary key, rather than pk. + self.assertEqual(Employee.objects.get(employee_code=123), self.dan) + + def test_pk_attributes(self): + """ + pk and attribute name are available on the model + No default id attribute is added + """ + # pk can be used as a substitute for the primary key. + # The primary key can be accessed via the pk property on the model. + e = Employee.objects.get(pk=123) + self.assertEqual(e.pk, 123) + # Or we can use the real attribute name for the primary key: + self.assertEqual(e.employee_code, 123) + + with self.assertRaisesMessage(AttributeError, "'Employee' object has no attribute 'id'"): + e.id + + def test_in_bulk(self): + """ + Custom pks work with in_bulk, both for integer and non-integer types + """ + emps = Employee.objects.in_bulk([123, 456]) + self.assertEqual(emps[123], self.dan) + + self.assertEqual(Business.objects.in_bulk(["Sears"]), { + "Sears": self.business, + }) + + def test_save(self): + """ + custom pks do not affect save + """ + fran = Employee.objects.get(pk=456) + fran.last_name = "Jones" + fran.save() + + self.assertQuerysetEqual( + Employee.objects.filter(last_name="Jones"), [ + "Dan Jones", + "Fran Jones", + ], + str + ) + + +class CustomPKTests(TestCase): + def test_custom_pk_create(self): + """ + New objects can be created both with pk and the custom name + """ + Employee.objects.create(employee_code=1234, first_name="Foo", last_name="Bar") 
+ Employee.objects.create(pk=1235, first_name="Foo", last_name="Baz") + Business.objects.create(name="Bears") + Business.objects.create(pk="Tears") + + def test_unicode_pk(self): + # Primary key may be unicode string + Business.objects.create(name='jaźń') + + def test_unique_pk(self): + # The primary key must also obviously be unique, so trying to create a + # new object with the same primary key will fail. + Employee.objects.create( + employee_code=123, first_name="Frank", last_name="Jones" + ) + with self.assertRaises(IntegrityError): + with transaction.atomic(): + Employee.objects.create(employee_code=123, first_name="Fred", last_name="Jones") + + def test_zero_non_autoincrement_pk(self): + Employee.objects.create( + employee_code=0, first_name="Frank", last_name="Jones" + ) + employee = Employee.objects.get(pk=0) + self.assertEqual(employee.employee_code, 0) + + def test_custom_field_pk(self): + # Regression for #10785 -- Custom fields can be used for primary keys. + new_bar = Bar.objects.create() + new_foo = Foo.objects.create(bar=new_bar) + + f = Foo.objects.get(bar=new_bar.pk) + self.assertEqual(f, new_foo) + self.assertEqual(f.bar, new_bar) + + f = Foo.objects.get(bar=new_bar) + self.assertEqual(f, new_foo), + self.assertEqual(f.bar, new_bar) + + # SQLite lets objects be saved with an empty primary key, even though an + # integer is expected. So we can't check for an error being raised in that + # case for SQLite. Remove it from the suite for this next bit. + @skipIfDBFeature('supports_unspecified_pk') + def test_required_pk(self): + # The primary key must be specified, so an error is raised if you + # try to create an object without it. 
+ with self.assertRaises(IntegrityError): + with transaction.atomic(): + Employee.objects.create(first_name="Tom", last_name="Smith") diff --git a/tests/datatypes/__init__.py b/tests/datatypes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/datatypes/models.py b/tests/datatypes/models.py new file mode 100644 index 00000000..cabe5297 --- /dev/null +++ b/tests/datatypes/models.py @@ -0,0 +1,29 @@ +""" +This is a basic model to test saving and loading boolean and date-related +types, which in the past were problematic for some database backends. +""" + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Donut(models.Model): + name = models.CharField(max_length=100) + is_frosted = models.BooleanField(default=False) + has_sprinkles = models.NullBooleanField() + baked_date = models.DateField(null=True) + baked_time = models.TimeField(null=True) + consumed_at = models.DateTimeField(null=True) + review = models.TextField() + + class Meta: + ordering = ('consumed_at',) + + def __str__(self): + return self.name + + +class RumBaba(models.Model): + baked_date = models.DateField(auto_now_add=True) + baked_timestamp = models.DateTimeField(auto_now_add=True) diff --git a/tests/datatypes/tests.py b/tests/datatypes/tests.py new file mode 100644 index 00000000..cf765677 --- /dev/null +++ b/tests/datatypes/tests.py @@ -0,0 +1,102 @@ +from __future__ import unicode_literals + +import datetime + +from django.test import TestCase, skipIfDBFeature +from django.utils import six +from django.utils.timezone import utc + +from .models import Donut, RumBaba + + +class DataTypesTestCase(TestCase): + + def test_boolean_type(self): + d = Donut(name='Apple Fritter') + self.assertFalse(d.is_frosted) + self.assertIsNone(d.has_sprinkles) + d.has_sprinkles = True + self.assertTrue(d.has_sprinkles) + + d.save() + + d2 = Donut.objects.get(name='Apple Fritter') + self.assertFalse(d2.is_frosted) 
+ self.assertTrue(d2.has_sprinkles) + + def test_date_type(self): + d = Donut(name='Apple Fritter') + d.baked_date = datetime.date(year=1938, month=6, day=4) + d.baked_time = datetime.time(hour=5, minute=30) + d.consumed_at = datetime.datetime(year=2007, month=4, day=20, hour=16, minute=19, second=59) + d.save() + + d2 = Donut.objects.get(name='Apple Fritter') + self.assertEqual(d2.baked_date, datetime.date(1938, 6, 4)) + self.assertEqual(d2.baked_time, datetime.time(5, 30)) + self.assertEqual(d2.consumed_at, datetime.datetime(2007, 4, 20, 16, 19, 59)) + + def test_time_field(self): + # Test for ticket #12059: TimeField wrongly handling datetime.datetime object. + d = Donut(name='Apple Fritter') + d.baked_time = datetime.datetime(year=2007, month=4, day=20, hour=16, minute=19, second=59) + d.save() + + d2 = Donut.objects.get(name='Apple Fritter') + self.assertEqual(d2.baked_time, datetime.time(16, 19, 59)) + + def test_year_boundaries(self): + """Year boundary tests (ticket #3689)""" + Donut.objects.create( + name='Date Test 2007', + baked_date=datetime.datetime(year=2007, month=12, day=31), + consumed_at=datetime.datetime(year=2007, month=12, day=31, hour=23, minute=59, second=59), + ) + Donut.objects.create( + name='Date Test 2006', + baked_date=datetime.datetime(year=2006, month=1, day=1), + consumed_at=datetime.datetime(year=2006, month=1, day=1), + ) + self.assertEqual("Date Test 2007", Donut.objects.filter(baked_date__year=2007)[0].name) + self.assertEqual("Date Test 2006", Donut.objects.filter(baked_date__year=2006)[0].name) + + Donut.objects.create( + name='Apple Fritter', + consumed_at=datetime.datetime(year=2007, month=4, day=20, hour=16, minute=19, second=59), + ) + + self.assertEqual( + ['Apple Fritter', 'Date Test 2007'], + list(Donut.objects.filter(consumed_at__year=2007).order_by('name').values_list('name', flat=True)) + ) + self.assertEqual(0, Donut.objects.filter(consumed_at__year=2005).count()) + self.assertEqual(0, 
Donut.objects.filter(consumed_at__year=2008).count()) + + def test_textfields_unicode(self): + """Regression test for #10238: TextField values returned from the + database should be unicode.""" + d = Donut.objects.create(name='Jelly Donut', review='Outstanding') + newd = Donut.objects.get(id=d.id) + self.assertIsInstance(newd.review, six.text_type) + + @skipIfDBFeature('supports_timezones') + def test_error_on_timezone(self): + """Regression test for #8354: the MySQL and Oracle backends should raise + an error if given a timezone-aware datetime object.""" + self.skipTest("TODO fix AssertionError: ValueError not raised") + dt = datetime.datetime(2008, 8, 31, 16, 20, tzinfo=utc) + d = Donut(name='Bear claw', consumed_at=dt) + # MySQL backend does not support timezone-aware datetimes. + with self.assertRaises(ValueError): + d.save() + + def test_datefield_auto_now_add(self): + """Regression test for #10970, auto_now_add for DateField should store + a Python datetime.date, not a datetime.datetime""" + b = RumBaba.objects.create() + # Verify we didn't break DateTimeField behavior + self.assertIsInstance(b.baked_timestamp, datetime.datetime) + # We need to test this way because datetime.datetime inherits + # from datetime.date: + self.assertIsInstance(b.baked_date, datetime.date) + self.assertNotIsInstance(b.baked_date, datetime.datetime) diff --git a/tests/dates/__init__.py b/tests/dates/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/dates/models.py b/tests/dates/models.py new file mode 100644 index 00000000..74f9db28 --- /dev/null +++ b/tests/dates/models.py @@ -0,0 +1,27 @@ +from django.db import models +from django.utils import timezone + + +class Article(models.Model): + title = models.CharField(max_length=100) + pub_date = models.DateField() + pub_datetime = models.DateTimeField(default=timezone.now()) + + categories = models.ManyToManyField("Category", related_name="articles") + + def __str__(self): + return self.title + + +class 
Comment(models.Model): + article = models.ForeignKey(Article, models.CASCADE, related_name="comments") + text = models.TextField() + pub_date = models.DateField() + approval_date = models.DateField(null=True) + + def __str__(self): + return 'Comment to %s (%s)' % (self.article.title, self.pub_date) + + +class Category(models.Model): + name = models.CharField(max_length=255) diff --git a/tests/dates/tests.py b/tests/dates/tests.py new file mode 100644 index 00000000..ebdf0581 --- /dev/null +++ b/tests/dates/tests.py @@ -0,0 +1,136 @@ +import datetime +from unittest import skipUnless + +from django.core.exceptions import FieldError +from django.db import connection +from django.test import TestCase, override_settings + +from .models import Article, Category, Comment + + +class DatesTests(TestCase): + def test_related_model_traverse(self): + a1 = Article.objects.create( + title="First one", + pub_date=datetime.date(2005, 7, 28), + ) + a2 = Article.objects.create( + title="Another one", + pub_date=datetime.date(2010, 7, 28), + ) + a3 = Article.objects.create( + title="Third one, in the first day", + pub_date=datetime.date(2005, 7, 28), + ) + + a1.comments.create( + text="Im the HULK!", + pub_date=datetime.date(2005, 7, 28), + ) + a1.comments.create( + text="HULK SMASH!", + pub_date=datetime.date(2005, 7, 29), + ) + a2.comments.create( + text="LMAO", + pub_date=datetime.date(2010, 7, 28), + ) + a3.comments.create( + text="+1", + pub_date=datetime.date(2005, 8, 29), + ) + + c = Category.objects.create(name="serious-news") + c.articles.add(a1, a3) + + self.assertSequenceEqual( + Comment.objects.dates("article__pub_date", "year"), [ + datetime.date(2005, 1, 1), + datetime.date(2010, 1, 1), + ], + ) + self.assertSequenceEqual( + Comment.objects.dates("article__pub_date", "month"), [ + datetime.date(2005, 7, 1), + datetime.date(2010, 7, 1), + ], + ) + self.assertSequenceEqual( + Comment.objects.dates("article__pub_date", "week"), [ + datetime.date(2005, 7, 25), + 
datetime.date(2010, 7, 26), + ], + ) + self.assertSequenceEqual( + Comment.objects.dates("article__pub_date", "day"), [ + datetime.date(2005, 7, 28), + datetime.date(2010, 7, 28), + ], + ) + self.assertSequenceEqual( + Article.objects.dates("comments__pub_date", "day"), [ + datetime.date(2005, 7, 28), + datetime.date(2005, 7, 29), + datetime.date(2005, 8, 29), + datetime.date(2010, 7, 28), + ], + ) + self.assertQuerysetEqual( + Article.objects.dates("comments__approval_date", "day"), [] + ) + self.assertSequenceEqual( + Category.objects.dates("articles__pub_date", "day"), [ + datetime.date(2005, 7, 28), + ], + ) + + def test_dates_fails_when_no_arguments_are_provided(self): + with self.assertRaises(TypeError): + Article.objects.dates() + + def test_dates_fails_when_given_invalid_field_argument(self): + self.assertRaisesMessage( + FieldError, + "Cannot resolve keyword 'invalid_field' into field. Choices are: " + "categories, comments, id, pub_date, pub_datetime, title", + Article.objects.dates, + "invalid_field", + "year", + ) + + def test_dates_fails_when_given_invalid_kind_argument(self): + msg = "'kind' must be one of 'year', 'month', 'week', or 'day'." 
+ with self.assertRaisesMessage(AssertionError, msg): + Article.objects.dates("pub_date", "bad_kind") + + def test_dates_fails_when_given_invalid_order_argument(self): + with self.assertRaisesMessage(AssertionError, "'order' must be either 'ASC' or 'DESC'."): + Article.objects.dates("pub_date", "year", order="bad order") + + @override_settings(USE_TZ=False) + def test_dates_trunc_datetime_fields(self): + Article.objects.bulk_create( + Article(pub_date=pub_datetime.date(), pub_datetime=pub_datetime) + for pub_datetime in [ + datetime.datetime(2015, 10, 21, 18, 1), + datetime.datetime(2015, 10, 21, 18, 2), + datetime.datetime(2015, 10, 22, 18, 1), + datetime.datetime(2015, 10, 22, 18, 2), + ] + ) + self.assertSequenceEqual( + Article.objects.dates('pub_datetime', 'day', order='ASC'), [ + datetime.date(2015, 10, 21), + datetime.date(2015, 10, 22), + ] + ) + + @skipUnless(connection.vendor == 'mysql', "Test checks MySQL query syntax") + def test_dates_avoid_datetime_cast(self): + Article.objects.create(pub_date=datetime.date(2015, 10, 21)) + for kind in ['day', 'month', 'year']: + qs = Article.objects.dates('pub_date', kind) + if kind == 'day': + self.assertIn('DATE(', str(qs.query)) + else: + self.assertIn(' AS DATE)', str(qs.query)) diff --git a/tests/datetimes/__init__.py b/tests/datetimes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/datetimes/models.py b/tests/datetimes/models.py new file mode 100644 index 00000000..2fcb72be --- /dev/null +++ b/tests/datetimes/models.py @@ -0,0 +1,31 @@ +from __future__ import unicode_literals + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Article(models.Model): + title = models.CharField(max_length=100) + pub_date = models.DateTimeField() + published_on = models.DateField(null=True) + + categories = models.ManyToManyField("Category", related_name="articles") + + def __str__(self): + return self.title + + 
+@python_2_unicode_compatible +class Comment(models.Model): + article = models.ForeignKey(Article, models.CASCADE, related_name="comments") + text = models.TextField() + pub_date = models.DateTimeField() + approval_date = models.DateTimeField(null=True) + + def __str__(self): + return 'Comment to %s (%s)' % (self.article.title, self.pub_date) + + +class Category(models.Model): + name = models.CharField(max_length=255) diff --git a/tests/datetimes/tests.py b/tests/datetimes/tests.py new file mode 100644 index 00000000..9117454d --- /dev/null +++ b/tests/datetimes/tests.py @@ -0,0 +1,153 @@ +from __future__ import unicode_literals + +import datetime + +import django +from django.test import TestCase, override_settings +from django.utils import timezone + +from .models import Article, Category, Comment + + +class DateTimesTests(TestCase): + def test_related_model_traverse(self): + a1 = Article.objects.create( + title="First one", + pub_date=datetime.datetime(2005, 7, 28, 9, 0, 0), + ) + a2 = Article.objects.create( + title="Another one", + pub_date=datetime.datetime(2010, 7, 28, 10, 0, 0), + ) + a3 = Article.objects.create( + title="Third one, in the first day", + pub_date=datetime.datetime(2005, 7, 28, 17, 0, 0), + ) + + a1.comments.create( + text="Im the HULK!", + pub_date=datetime.datetime(2005, 7, 28, 9, 30, 0), + ) + a1.comments.create( + text="HULK SMASH!", + pub_date=datetime.datetime(2005, 7, 29, 1, 30, 0), + ) + a2.comments.create( + text="LMAO", + pub_date=datetime.datetime(2010, 7, 28, 10, 10, 10), + ) + a3.comments.create( + text="+1", + pub_date=datetime.datetime(2005, 8, 29, 10, 10, 10), + ) + + c = Category.objects.create(name="serious-news") + c.articles.add(a1, a3) + + self.assertSequenceEqual( + Comment.objects.datetimes("article__pub_date", "year"), [ + datetime.datetime(2005, 1, 1), + datetime.datetime(2010, 1, 1), + ], + ) + self.assertSequenceEqual( + Comment.objects.datetimes("article__pub_date", "month"), [ + datetime.datetime(2005, 7, 1), + 
datetime.datetime(2010, 7, 1), + ], + ) + self.assertSequenceEqual( + Comment.objects.datetimes("article__pub_date", "day"), [ + datetime.datetime(2005, 7, 28), + datetime.datetime(2010, 7, 28), + ], + ) + self.assertSequenceEqual( + Article.objects.datetimes("comments__pub_date", "day"), [ + datetime.datetime(2005, 7, 28), + datetime.datetime(2005, 7, 29), + datetime.datetime(2005, 8, 29), + datetime.datetime(2010, 7, 28), + ], + ) + self.assertQuerysetEqual( + Article.objects.datetimes("comments__approval_date", "day"), [] + ) + self.assertSequenceEqual( + Category.objects.datetimes("articles__pub_date", "day"), [ + datetime.datetime(2005, 7, 28), + ], + ) + + @override_settings(USE_TZ=True) + def test_21432(self): + self.skipTest("TODO fix AssertionError: datet[20 chars], 9, 16, 59, 32, tzinfo=) != datet[20 chars], 9, 22, 59, 32, tzinfo=)") + now = timezone.localtime(timezone.now().replace(microsecond=0)) + Article.objects.create(title="First one", pub_date=now) + qs = Article.objects.datetimes('pub_date', 'second') + self.assertEqual(qs[0], now) + + def test_datetimes_returns_available_dates_for_given_scope_and_given_field(self): + pub_dates = [ + datetime.datetime(2005, 7, 28, 12, 15), + datetime.datetime(2005, 7, 29, 2, 15), + datetime.datetime(2005, 7, 30, 5, 15), + datetime.datetime(2005, 7, 31, 19, 15)] + for i, pub_date in enumerate(pub_dates): + Article(pub_date=pub_date, title='title #{}'.format(i)).save() + + self.assertQuerysetEqual( + Article.objects.datetimes('pub_date', 'year'), + ["datetime.datetime(2005, 1, 1, 0, 0)"]) + self.assertQuerysetEqual( + Article.objects.datetimes('pub_date', 'month'), + ["datetime.datetime(2005, 7, 1, 0, 0)"]) + self.assertQuerysetEqual( + Article.objects.datetimes('pub_date', 'day'), + ["datetime.datetime(2005, 7, 28, 0, 0)", + "datetime.datetime(2005, 7, 29, 0, 0)", + "datetime.datetime(2005, 7, 30, 0, 0)", + "datetime.datetime(2005, 7, 31, 0, 0)"]) + self.assertQuerysetEqual( + Article.objects.datetimes('pub_date', 
'day', order='ASC'), + ["datetime.datetime(2005, 7, 28, 0, 0)", + "datetime.datetime(2005, 7, 29, 0, 0)", + "datetime.datetime(2005, 7, 30, 0, 0)", + "datetime.datetime(2005, 7, 31, 0, 0)"]) + self.assertQuerysetEqual( + Article.objects.datetimes('pub_date', 'day', order='DESC'), + ["datetime.datetime(2005, 7, 31, 0, 0)", + "datetime.datetime(2005, 7, 30, 0, 0)", + "datetime.datetime(2005, 7, 29, 0, 0)", + "datetime.datetime(2005, 7, 28, 0, 0)"]) + + def test_datetimes_has_lazy_iterator(self): + pub_dates = [ + datetime.datetime(2005, 7, 28, 12, 15), + datetime.datetime(2005, 7, 29, 2, 15), + datetime.datetime(2005, 7, 30, 5, 15), + datetime.datetime(2005, 7, 31, 19, 15)] + for i, pub_date in enumerate(pub_dates): + Article(pub_date=pub_date, title='title #{}'.format(i)).save() + # Use iterator() with datetimes() to return a generator that lazily + # requests each result one at a time, to save memory. + dates = [] + with self.assertNumQueries(0): + article_datetimes_iterator = Article.objects.datetimes('pub_date', 'day', order='DESC').iterator() + + with self.assertNumQueries(1): + for article in article_datetimes_iterator: + dates.append(article) + self.assertEqual(dates, [ + datetime.datetime(2005, 7, 31, 0, 0), + datetime.datetime(2005, 7, 30, 0, 0), + datetime.datetime(2005, 7, 29, 0, 0), + datetime.datetime(2005, 7, 28, 0, 0)]) + + def test_datetimes_disallows_date_fields(self): + if django.VERSION < (1, 10, 0): + self.skipTest("TODO fix AssertionError: 'published_on' isn't a DateTimeField.") + dt = datetime.datetime(2005, 7, 28, 12, 15) + Article.objects.create(pub_date=dt, published_on=dt.date(), title="Don't put dates into datetime functions!") + with self.assertRaisesMessage(ValueError, "Cannot truncate DateField 'published_on' to DateTimeField"): + list(Article.objects.datetimes('published_on', 'second')) diff --git a/tests/db_typecasts/__init__.py b/tests/db_typecasts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/tests/db_typecasts/tests.py b/tests/db_typecasts/tests.py new file mode 100644 index 00000000..fa9eab16 --- /dev/null +++ b/tests/db_typecasts/tests.py @@ -0,0 +1,63 @@ +# Unit tests for typecast functions in django.db.backends.util + +import datetime +import unittest + +from django.db.backends import utils as typecasts +from django.utils import six + +TEST_CASES = { + 'typecast_date': ( + ('', None), + (None, None), + ('2005-08-11', datetime.date(2005, 8, 11)), + ('1990-01-01', datetime.date(1990, 1, 1)), + ), + 'typecast_time': ( + ('', None), + (None, None), + ('0:00:00', datetime.time(0, 0)), + ('0:30:00', datetime.time(0, 30)), + ('8:50:00', datetime.time(8, 50)), + ('08:50:00', datetime.time(8, 50)), + ('12:00:00', datetime.time(12, 00)), + ('12:30:00', datetime.time(12, 30)), + ('13:00:00', datetime.time(13, 00)), + ('23:59:00', datetime.time(23, 59)), + ('00:00:12', datetime.time(0, 0, 12)), + ('00:00:12.5', datetime.time(0, 0, 12, 500000)), + ('7:22:13.312', datetime.time(7, 22, 13, 312000)), + ('12:45:30.126631', datetime.time(12, 45, 30, 126631)), + ('12:45:30.126630', datetime.time(12, 45, 30, 126630)), + ('12:45:30.123456789', datetime.time(12, 45, 30, 123456)), + ), + 'typecast_timestamp': ( + ('', None), + (None, None), + ('2005-08-11 0:00:00', datetime.datetime(2005, 8, 11)), + ('2005-08-11 0:30:00', datetime.datetime(2005, 8, 11, 0, 30)), + ('2005-08-11 8:50:30', datetime.datetime(2005, 8, 11, 8, 50, 30)), + ('2005-08-11 8:50:30.123', datetime.datetime(2005, 8, 11, 8, 50, 30, 123000)), + ('2005-08-11 8:50:30.9', datetime.datetime(2005, 8, 11, 8, 50, 30, 900000)), + ('2005-08-11 8:50:30.312-05', datetime.datetime(2005, 8, 11, 8, 50, 30, 312000)), + ('2005-08-11 8:50:30.312+02', datetime.datetime(2005, 8, 11, 8, 50, 30, 312000)), + # ticket 14453 + ('2010-10-12 15:29:22.063202', datetime.datetime(2010, 10, 12, 15, 29, 22, 63202)), + ('2010-10-12 15:29:22.063202-03', datetime.datetime(2010, 10, 12, 15, 29, 22, 63202)), + ('2010-10-12 
15:29:22.063202+04', datetime.datetime(2010, 10, 12, 15, 29, 22, 63202)), + ('2010-10-12 15:29:22.0632021', datetime.datetime(2010, 10, 12, 15, 29, 22, 63202)), + ('2010-10-12 15:29:22.0632029', datetime.datetime(2010, 10, 12, 15, 29, 22, 63202)), + ), +} + + +class DBTypeCasts(unittest.TestCase): + def test_typeCasts(self): + for k, v in six.iteritems(TEST_CASES): + for inpt, expected in v: + got = getattr(typecasts, k)(inpt) + self.assertEqual( + got, + expected, + "In %s: %r doesn't match %r. Got %r instead." % (k, inpt, expected, got) + ) diff --git a/tests/defer/__init__.py b/tests/defer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/defer/models.py b/tests/defer/models.py new file mode 100644 index 00000000..b36b1735 --- /dev/null +++ b/tests/defer/models.py @@ -0,0 +1,48 @@ +""" +Tests for defer() and only(). +""" + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +class Secondary(models.Model): + first = models.CharField(max_length=50) + second = models.CharField(max_length=50) + + +@python_2_unicode_compatible +class Primary(models.Model): + name = models.CharField(max_length=50) + value = models.CharField(max_length=50) + related = models.ForeignKey(Secondary, models.CASCADE) + + def __str__(self): + return self.name + + +class Child(Primary): + pass + + +class BigChild(Primary): + other = models.CharField(max_length=50) + + +class ChildProxy(Child): + class Meta: + proxy = True + + +class RefreshPrimaryProxy(Primary): + class Meta: + proxy = True + + def refresh_from_db(self, using=None, fields=None, **kwargs): + # Reloads all deferred fields if any of the fields is deferred. 
+ if fields is not None: + fields = set(fields) + deferred_fields = self.get_deferred_fields() + if fields.intersection(deferred_fields): + fields = fields.union(deferred_fields) + super(RefreshPrimaryProxy, self).refresh_from_db(using, fields, **kwargs) diff --git a/tests/defer/tests.py b/tests/defer/tests.py new file mode 100644 index 00000000..35fa77ea --- /dev/null +++ b/tests/defer/tests.py @@ -0,0 +1,274 @@ +from __future__ import unicode_literals + +from django.db.models.query_utils import InvalidQuery +from django.test import TestCase + +from .models import ( + BigChild, Child, ChildProxy, Primary, RefreshPrimaryProxy, Secondary, +) + + +class AssertionMixin(object): + def assert_delayed(self, obj, num): + """ + Instances with deferred fields look the same as normal instances when + we examine attribute values. Therefore, this method returns the number + of deferred fields on returned instances. + """ + count = len(obj.get_deferred_fields()) + self.assertEqual(count, num) + + +class DeferTests(AssertionMixin, TestCase): + @classmethod + def setUpTestData(cls): + cls.s1 = Secondary.objects.create(first="x1", second="y1") + cls.p1 = Primary.objects.create(name="p1", value="xx", related=cls.s1) + + def test_defer(self): + qs = Primary.objects.all() + self.assert_delayed(qs.defer("name")[0], 1) + self.assert_delayed(qs.defer("name").get(pk=self.p1.pk), 1) + self.assert_delayed(qs.defer("related__first")[0], 0) + self.assert_delayed(qs.defer("name").defer("value")[0], 2) + + def test_only(self): + # TODO: fix + return + qs = Primary.objects.all() + self.assert_delayed(qs.only("name")[0], 2) + self.assert_delayed(qs.only("name").get(pk=self.p1.pk), 2) + self.assert_delayed(qs.only("name").only("value")[0], 2) + self.assert_delayed(qs.only("related__first")[0], 2) + # Using 'pk' with only() should result in 3 deferred fields, namely all + # of them except the model's primary key see #15494 + self.assert_delayed(qs.only("pk")[0], 3) + # You can use 'pk' with 
reverse foreign key lookups. + The related_id is always set even if it's not fetched from the DB, + so pk and related_id are not deferred. + self.assert_delayed(self.s1.primary_set.all().only('pk')[0], 2) + + def test_defer_only_chaining(self): + qs = Primary.objects.all() + self.assert_delayed(qs.only("name", "value").defer("name")[0], 2) + self.assert_delayed(qs.defer("name").only("value", "name")[0], 2) + self.assert_delayed(qs.defer("name").only("value")[0], 2) + self.assert_delayed(qs.only("name").defer("value")[0], 2) + + def test_defer_on_an_already_deferred_field(self): + qs = Primary.objects.all() + self.assert_delayed(qs.defer("name")[0], 1) + self.assert_delayed(qs.defer("name").defer("name")[0], 1) + + def test_defer_none_to_clear_deferred_set(self): + qs = Primary.objects.all() + self.assert_delayed(qs.defer("name", "value")[0], 2) + self.assert_delayed(qs.defer(None)[0], 0) + self.assert_delayed(qs.only("name").defer(None)[0], 0) + + def test_only_none_raises_error(self): + msg = 'Cannot pass None as an argument to only().' + with self.assertRaisesMessage(TypeError, msg): + Primary.objects.only(None) + + def test_defer_extra(self): + qs = Primary.objects.all() + self.assert_delayed(qs.defer("name").extra(select={"a": 1})[0], 1) + self.assert_delayed(qs.extra(select={"a": 1}).defer("name")[0], 1) + + def test_defer_values_does_not_defer(self): + # Using values() won't defer anything (you get the full list of + # dictionaries back), but it still works. + self.assertEqual(Primary.objects.defer("name").values()[0], { + "id": self.p1.id, + "name": "p1", + "value": "xx", + "related_id": self.s1.id, + }) + + def test_only_values_does_not_defer(self): + self.assertEqual(Primary.objects.only("name").values()[0], { + "id": self.p1.id, + "name": "p1", + "value": "xx", + "related_id": self.s1.id, + }) + + def test_get(self): + # Using defer() and only() with get() is also valid. 
+ qs = Primary.objects.all() + self.assert_delayed(qs.defer("name").get(pk=self.p1.pk), 1) + self.assert_delayed(qs.only("name").get(pk=self.p1.pk), 2) + + def test_defer_with_select_related(self): + obj = Primary.objects.select_related().defer("related__first", "related__second")[0] + self.assert_delayed(obj.related, 2) + self.assert_delayed(obj, 0) + + def test_only_with_select_related(self): + obj = Primary.objects.select_related().only("related__first")[0] + self.assert_delayed(obj, 2) + self.assert_delayed(obj.related, 1) + self.assertEqual(obj.related_id, self.s1.pk) + self.assertEqual(obj.name, "p1") + + def test_defer_select_related_raises_invalid_query(self): + msg = ( + 'Field Primary.related cannot be both deferred and traversed ' + 'using select_related at the same time.' + ) + with self.assertRaisesMessage(InvalidQuery, msg): + Primary.objects.defer("related").select_related("related")[0] + + def test_only_select_related_raises_invalid_query(self): + msg = ( + 'Field Primary.related cannot be both deferred and traversed using ' + 'select_related at the same time.' + ) + with self.assertRaisesMessage(InvalidQuery, msg): + Primary.objects.only("name").select_related("related")[0] + + def test_defer_foreign_keys_are_deferred_and_not_traversed(self): + # TODO: fix + return + # select_related() overrides defer(). + with self.assertNumQueries(1): + obj = Primary.objects.defer("related").select_related()[0] + self.assert_delayed(obj, 1) + self.assertEqual(obj.related.id, self.s1.pk) + + def test_saving_object_with_deferred_field(self): + # Saving models with deferred fields is possible (but inefficient, + # since every field has to be retrieved first). 
+ Primary.objects.create(name="p2", value="xy", related=self.s1) + obj = Primary.objects.defer("value").get(name="p2") + obj.name = "a new name" + obj.save() + self.assertQuerysetEqual( + Primary.objects.all(), [ + "p1", "a new name", + ], + lambda p: p.name, + ordered=False, + ) + + def test_defer_baseclass_when_subclass_has_no_added_fields(self): + # Regression for #10572 - A subclass with no extra fields can defer + # fields from the base class + Child.objects.create(name="c1", value="foo", related=self.s1) + # You can defer a field on a baseclass when the subclass has no fields + obj = Child.objects.defer("value").get(name="c1") + self.assert_delayed(obj, 1) + self.assertEqual(obj.name, "c1") + self.assertEqual(obj.value, "foo") + + def test_only_baseclass_when_subclass_has_no_added_fields(self): + # You can retrieve a single column on a base class with no fields + Child.objects.create(name="c1", value="foo", related=self.s1) + obj = Child.objects.only("name").get(name="c1") + # on an inherited model, its PK is also fetched, hence '3' deferred fields. 
+ self.assert_delayed(obj, 3) + self.assertEqual(obj.name, "c1") + self.assertEqual(obj.value, "foo") + + +class BigChildDeferTests(AssertionMixin, TestCase): + @classmethod + def setUpTestData(cls): + cls.s1 = Secondary.objects.create(first="x1", second="y1") + BigChild.objects.create(name="b1", value="foo", related=cls.s1, other="bar") + + def test_defer_baseclass_when_subclass_has_added_field(self): + # You can defer a field on a baseclass + obj = BigChild.objects.defer("value").get(name="b1") + self.assert_delayed(obj, 1) + self.assertEqual(obj.name, "b1") + self.assertEqual(obj.value, "foo") + self.assertEqual(obj.other, "bar") + + def test_defer_subclass(self): + # You can defer a field on a subclass + obj = BigChild.objects.defer("other").get(name="b1") + self.assert_delayed(obj, 1) + self.assertEqual(obj.name, "b1") + self.assertEqual(obj.value, "foo") + self.assertEqual(obj.other, "bar") + + def test_only_baseclass_when_subclass_has_added_field(self): + # You can retrieve a single field on a baseclass + obj = BigChild.objects.only("name").get(name="b1") + # when inherited model, its PK is also fetched, hence '4' deferred fields. + self.assert_delayed(obj, 4) + self.assertEqual(obj.name, "b1") + self.assertEqual(obj.value, "foo") + self.assertEqual(obj.other, "bar") + + def test_only_sublcass(self): + # You can retrieve a single field on a subclass + obj = BigChild.objects.only("other").get(name="b1") + self.assert_delayed(obj, 4) + self.assertEqual(obj.name, "b1") + self.assertEqual(obj.value, "foo") + self.assertEqual(obj.other, "bar") + + +class TestDefer2(AssertionMixin, TestCase): + def test_defer_proxy(self): + """ + Ensure select_related together with only on a proxy model behaves + as expected. See #17876. 
+ """ + related = Secondary.objects.create(first='x1', second='x2') + ChildProxy.objects.create(name='p1', value='xx', related=related) + children = ChildProxy.objects.all().select_related().only('id', 'name') + self.assertEqual(len(children), 1) + child = children[0] + self.assert_delayed(child, 2) + self.assertEqual(child.name, 'p1') + self.assertEqual(child.value, 'xx') + + def test_defer_inheritance_pk_chaining(self): + """ + When an inherited model is fetched from the DB, its PK is also fetched. + When getting the PK of the parent model it is useful to use the already + fetched parent model PK if it happens to be available. + """ + s1 = Secondary.objects.create(first="x1", second="y1") + bc = BigChild.objects.create(name="b1", value="foo", related=s1, + other="bar") + bc_deferred = BigChild.objects.only('name').get(pk=bc.pk) + with self.assertNumQueries(0): + bc_deferred.id + self.assertEqual(bc_deferred.pk, bc_deferred.id) + + def test_eq(self): + s1 = Secondary.objects.create(first="x1", second="y1") + s1_defer = Secondary.objects.only('pk').get(pk=s1.pk) + self.assertEqual(s1, s1_defer) + self.assertEqual(s1_defer, s1) + + def test_refresh_not_loading_deferred_fields(self): + s = Secondary.objects.create() + rf = Primary.objects.create(name='foo', value='bar', related=s) + rf2 = Primary.objects.only('related', 'value').get() + rf.name = 'new foo' + rf.value = 'new bar' + rf.save() + with self.assertNumQueries(1): + rf2.refresh_from_db() + self.assertEqual(rf2.value, 'new bar') + with self.assertNumQueries(1): + self.assertEqual(rf2.name, 'new foo') + + def test_custom_refresh_on_deferred_loading(self): + s = Secondary.objects.create() + rf = RefreshPrimaryProxy.objects.create(name='foo', value='bar', related=s) + rf2 = RefreshPrimaryProxy.objects.only('related').get() + rf.name = 'new foo' + rf.value = 'new bar' + rf.save() + with self.assertNumQueries(1): + # Customized refresh_from_db() reloads all deferred fields on + # access of any of them. 
+ self.assertEqual(rf2.name, 'new foo') + self.assertEqual(rf2.value, 'new bar') diff --git a/tests/defer_regress/__init__.py b/tests/defer_regress/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/defer_regress/models.py b/tests/defer_regress/models.py new file mode 100644 index 00000000..a73f539b --- /dev/null +++ b/tests/defer_regress/models.py @@ -0,0 +1,106 @@ +""" +Regression tests for defer() / only() behavior. +""" + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Item(models.Model): + name = models.CharField(max_length=15) + text = models.TextField(default="xyzzy") + value = models.IntegerField() + other_value = models.IntegerField(default=0) + + def __str__(self): + return self.name + + +class RelatedItem(models.Model): + item = models.ForeignKey(Item, models.CASCADE) + + +class ProxyRelated(RelatedItem): + class Meta: + proxy = True + + +class Child(models.Model): + name = models.CharField(max_length=10) + value = models.IntegerField() + + +@python_2_unicode_compatible +class Leaf(models.Model): + name = models.CharField(max_length=10) + child = models.ForeignKey(Child, models.CASCADE) + second_child = models.ForeignKey(Child, models.SET_NULL, related_name="other", null=True) + value = models.IntegerField(default=42) + + def __str__(self): + return self.name + + +class ResolveThis(models.Model): + num = models.FloatField() + name = models.CharField(max_length=16) + + +class Proxy(Item): + class Meta: + proxy = True + + +@python_2_unicode_compatible +class SimpleItem(models.Model): + name = models.CharField(max_length=15) + value = models.IntegerField() + + def __str__(self): + return self.name + + +class Feature(models.Model): + item = models.ForeignKey(SimpleItem, models.CASCADE) + + +class SpecialFeature(models.Model): + feature = models.ForeignKey(Feature, models.CASCADE) + + +class OneToOneItem(models.Model): + item = 
models.OneToOneField(Item, models.CASCADE, related_name="one_to_one_item") + name = models.CharField(max_length=15) + + +class ItemAndSimpleItem(models.Model): + item = models.ForeignKey(Item, models.CASCADE) + simple = models.ForeignKey(SimpleItem, models.CASCADE) + + +class Profile(models.Model): + profile1 = models.CharField(max_length=1000, default='profile1') + + +class Location(models.Model): + location1 = models.CharField(max_length=1000, default='location1') + + +class Request(models.Model): + profile = models.ForeignKey(Profile, models.SET_NULL, null=True, blank=True) + location = models.ForeignKey(Location, models.CASCADE) + items = models.ManyToManyField(Item) + + request1 = models.CharField(default='request1', max_length=1000) + request2 = models.CharField(default='request2', max_length=1000) + request3 = models.CharField(default='request3', max_length=1000) + request4 = models.CharField(default='request4', max_length=1000) + + +class Base(models.Model): + text = models.TextField() + + +class Derived(Base): + other_text = models.TextField() diff --git a/tests/defer_regress/tests.py b/tests/defer_regress/tests.py new file mode 100644 index 00000000..76bdf277 --- /dev/null +++ b/tests/defer_regress/tests.py @@ -0,0 +1,282 @@ +from __future__ import unicode_literals + +from operator import attrgetter + +import django +from django.contrib.contenttypes.models import ContentType +from django.contrib.sessions.backends.db import SessionStore +from django.db import models +from django.db.models import Count +from django.test import TestCase, override_settings + +from .models import ( + Base, Child, Derived, Feature, Item, ItemAndSimpleItem, Leaf, Location, + OneToOneItem, Proxy, ProxyRelated, RelatedItem, Request, ResolveThis, + SimpleItem, SpecialFeature, +) + + +class DeferRegressionTest(TestCase): + def test_basic(self): + # Deferred fields should really be deferred and not accidentally use + # the field's default value just because they aren't passed to 
__init__ + + Item.objects.create(name="first", value=42) + obj = Item.objects.only("name", "other_value").get(name="first") + # Accessing "name" doesn't trigger a new database query. Accessing + # "value" or "text" should. + with self.assertNumQueries(0): + self.assertEqual(obj.name, "first") + self.assertEqual(obj.other_value, 0) + + with self.assertNumQueries(1): + self.assertEqual(obj.value, 42) + + with self.assertNumQueries(1): + self.assertEqual(obj.text, "xyzzy") + + with self.assertNumQueries(0): + self.assertEqual(obj.text, "xyzzy") + + # Regression test for #10695. Make sure different instances don't + # inadvertently share data in the deferred descriptor objects. + i = Item.objects.create(name="no I'm first", value=37) + items = Item.objects.only("value").order_by("-value") + self.assertEqual(items[0].name, "first") + self.assertEqual(items[1].name, "no I'm first") + + RelatedItem.objects.create(item=i) + r = RelatedItem.objects.defer("item").get() + self.assertEqual(r.item_id, i.id) + self.assertEqual(r.item, i) + + # Some further checks for select_related() and inherited model + # behavior (regression for #10710). + c1 = Child.objects.create(name="c1", value=42) + c2 = Child.objects.create(name="c2", value=37) + Leaf.objects.create(name="l1", child=c1, second_child=c2) + + obj = Leaf.objects.only("name", "child").select_related()[0] + self.assertEqual(obj.child.name, "c1") + + self.assertQuerysetEqual( + Leaf.objects.select_related().only("child__name", "second_child__name"), [ + "l1", + ], + attrgetter("name") + ) + + # Models instances with deferred fields should still return the same + # content types as their non-deferred versions (bug #10738). + ctype = ContentType.objects.get_for_model + c1 = ctype(Item.objects.all()[0]) + c2 = ctype(Item.objects.defer("name")[0]) + c3 = ctype(Item.objects.only("name")[0]) + self.assertTrue(c1 is c2 is c3) + + # Regression for #10733 - only() can be used on a model with two + # foreign keys. 
+ results = Leaf.objects.only("name", "child", "second_child").select_related() + self.assertEqual(results[0].child.name, "c1") + self.assertEqual(results[0].second_child.name, "c2") + + results = Leaf.objects.only( + "name", "child", "second_child", "child__name", "second_child__name" + ).select_related() + self.assertEqual(results[0].child.name, "c1") + self.assertEqual(results[0].second_child.name, "c2") + + # Regression for #16409 - make sure defer() and only() work with annotate() + self.assertIsInstance( + list(SimpleItem.objects.annotate(Count('feature')).defer('name')), + list) + self.assertIsInstance( + list(SimpleItem.objects.annotate(Count('feature')).only('name')), + list) + + @override_settings(SESSION_SERIALIZER='django.contrib.sessions.serializers.PickleSerializer') + def test_ticket_12163(self): + # Test for #12163 - Pickling error saving session with unsaved model + # instances. + SESSION_KEY = '2b1189a188b44ad18c35e1baac6ceead' + + item = Item() + item._deferred = False + s = SessionStore(SESSION_KEY) + s.clear() + s["item"] = item + s.save(must_create=True) + + s = SessionStore(SESSION_KEY) + s.modified = True + s.save() + + i2 = s["item"] + self.assertFalse(i2._deferred) + + def test_ticket_16409(self): + # Regression for #16409 - make sure defer() and only() work with annotate() + self.assertIsInstance( + list(SimpleItem.objects.annotate(Count('feature')).defer('name')), + list) + self.assertIsInstance( + list(SimpleItem.objects.annotate(Count('feature')).only('name')), + list) + + def test_ticket_23270(self): + Derived.objects.create(text="foo", other_text="bar") + with self.assertNumQueries(1): + obj = Base.objects.select_related("derived").defer("text")[0] + self.assertIsInstance(obj.derived, Derived) + self.assertEqual("bar", obj.derived.other_text) + self.assertNotIn("text", obj.__dict__) + self.assertEqual(1, obj.derived.base_ptr_id) + + def test_only_and_defer_usage_on_proxy_models(self): + # Regression for #15790 - only() broken for 
proxy models + proxy = Proxy.objects.create(name="proxy", value=42) + + msg = 'QuerySet.only() return bogus results with proxy models' + dp = Proxy.objects.only('other_value').get(pk=proxy.pk) + self.assertEqual(dp.name, proxy.name, msg=msg) + self.assertEqual(dp.value, proxy.value, msg=msg) + + # also test things with .defer() + msg = 'QuerySet.defer() return bogus results with proxy models' + dp = Proxy.objects.defer('name', 'text', 'value').get(pk=proxy.pk) + self.assertEqual(dp.name, proxy.name, msg=msg) + self.assertEqual(dp.value, proxy.value, msg=msg) + + def test_resolve_columns(self): + ResolveThis.objects.create(num=5.0, name='Foobar') + qs = ResolveThis.objects.defer('num') + self.assertEqual(1, qs.count()) + self.assertEqual('Foobar', qs[0].name) + + def test_reverse_one_to_one_relations(self): + # Refs #14694. Test reverse relations which are known unique (reverse + # side has o2ofield or unique FK) - the o2o case + item = Item.objects.create(name="first", value=42) + o2o = OneToOneItem.objects.create(item=item, name="second") + self.assertEqual(len(Item.objects.defer('one_to_one_item__name')), 1) + self.assertEqual(len(Item.objects.select_related('one_to_one_item')), 1) + self.assertEqual(len(Item.objects.select_related( + 'one_to_one_item').defer('one_to_one_item__name')), 1) + self.assertEqual(len(Item.objects.select_related('one_to_one_item').defer('value')), 1) + # Make sure that `only()` doesn't break when we pass in a unique relation, + # rather than a field on the relation. 
+ self.assertEqual(len(Item.objects.only('one_to_one_item')), 1) + with self.assertNumQueries(1): + i = Item.objects.select_related('one_to_one_item')[0] + self.assertEqual(i.one_to_one_item.pk, o2o.pk) + self.assertEqual(i.one_to_one_item.name, "second") + with self.assertNumQueries(1): + i = Item.objects.select_related('one_to_one_item').defer( + 'value', 'one_to_one_item__name')[0] + self.assertEqual(i.one_to_one_item.pk, o2o.pk) + self.assertEqual(i.name, "first") + with self.assertNumQueries(1): + self.assertEqual(i.one_to_one_item.name, "second") + with self.assertNumQueries(1): + self.assertEqual(i.value, 42) + + def test_defer_with_select_related(self): + item1 = Item.objects.create(name="first", value=47) + item2 = Item.objects.create(name="second", value=42) + simple = SimpleItem.objects.create(name="simple", value="23") + ItemAndSimpleItem.objects.create(item=item1, simple=simple) + + obj = ItemAndSimpleItem.objects.defer('item').select_related('simple').get() + self.assertEqual(obj.item, item1) + self.assertEqual(obj.item_id, item1.id) + + obj.item = item2 + obj.save() + + obj = ItemAndSimpleItem.objects.defer('item').select_related('simple').get() + self.assertEqual(obj.item, item2) + self.assertEqual(obj.item_id, item2.id) + + def test_proxy_model_defer_with_select_related(self): + # Regression for #22050 + item = Item.objects.create(name="first", value=47) + RelatedItem.objects.create(item=item) + # Defer fields with only() + obj = ProxyRelated.objects.all().select_related().only('item__name')[0] + with self.assertNumQueries(0): + self.assertEqual(obj.item.name, "first") + with self.assertNumQueries(1): + self.assertEqual(obj.item.value, 47) + + def test_only_with_select_related(self): + # Test for #17485. 
+ item = SimpleItem.objects.create(name='first', value=47) + feature = Feature.objects.create(item=item) + SpecialFeature.objects.create(feature=feature) + + qs = Feature.objects.only('item__name').select_related('item') + self.assertEqual(len(qs), 1) + + qs = SpecialFeature.objects.only('feature__item__name').select_related('feature__item') + self.assertEqual(len(qs), 1) + + +class DeferAnnotateSelectRelatedTest(TestCase): + def test_defer_annotate_select_related(self): + location = Location.objects.create() + Request.objects.create(location=location) + self.assertIsInstance( + list(Request.objects.annotate(Count('items')).select_related('profile', 'location') + .only('profile', 'location')), + list + ) + self.assertIsInstance( + list(Request.objects.annotate(Count('items')).select_related('profile', 'location') + .only('profile__profile1', 'location__location1')), + list + ) + self.assertIsInstance( + list(Request.objects.annotate(Count('items')).select_related('profile', 'location') + .defer('request1', 'request2', 'request3', 'request4')), + list + ) + + +class DeferDeletionSignalsTests(TestCase): + senders = [Item, Proxy] + + @classmethod + def setUpTestData(cls): + cls.item_pk = Item.objects.create(value=1).pk + + def setUp(self): + self.pre_delete_senders = [] + self.post_delete_senders = [] + for sender in self.senders: + models.signals.pre_delete.connect(self.pre_delete_receiver, sender) + models.signals.post_delete.connect(self.post_delete_receiver, sender) + + def tearDown(self): + for sender in self.senders: + models.signals.pre_delete.disconnect(self.pre_delete_receiver, sender) + models.signals.post_delete.disconnect(self.post_delete_receiver, sender) + + def pre_delete_receiver(self, sender, **kwargs): + self.pre_delete_senders.append(sender) + + def post_delete_receiver(self, sender, **kwargs): + self.post_delete_senders.append(sender) + + def test_delete_defered_model(self): + if django.VERSION < (1, 10, 0): + self.skipTest('This does not work on 
older Django') + Item.objects.only('value').get(pk=self.item_pk).delete() + self.assertEqual(self.pre_delete_senders, [Item]) + self.assertEqual(self.post_delete_senders, [Item]) + + def test_delete_defered_proxy_model(self): + if django.VERSION < (1, 10, 0): + self.skipTest('This does not work on older Django') + Proxy.objects.only('value').get(pk=self.item_pk).delete() + self.assertEqual(self.pre_delete_senders, [Proxy]) + self.assertEqual(self.post_delete_senders, [Proxy]) diff --git a/tests/delete_regress/__init__.py b/tests/delete_regress/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/delete_regress/models.py b/tests/delete_regress/models.py new file mode 100644 index 00000000..f0145de6 --- /dev/null +++ b/tests/delete_regress/models.py @@ -0,0 +1,141 @@ +from django.contrib.contenttypes.fields import ( + GenericForeignKey, GenericRelation, +) +from django.contrib.contenttypes.models import ContentType +from django.db import models + + +class Award(models.Model): + name = models.CharField(max_length=25) + object_id = models.PositiveIntegerField() + content_type = models.ForeignKey(ContentType, models.CASCADE) + content_object = GenericForeignKey() + + +class AwardNote(models.Model): + award = models.ForeignKey(Award, models.CASCADE) + note = models.CharField(max_length=100) + + +class Person(models.Model): + name = models.CharField(max_length=25) + awards = GenericRelation(Award) + + +class Book(models.Model): + pagecount = models.IntegerField() + + +class Toy(models.Model): + name = models.CharField(max_length=50) + + +class Child(models.Model): + name = models.CharField(max_length=50) + toys = models.ManyToManyField(Toy, through='PlayedWith') + + +class PlayedWith(models.Model): + child = models.ForeignKey(Child, models.CASCADE) + toy = models.ForeignKey(Toy, models.CASCADE) + date = models.DateField(db_column='date_col') + + +class PlayedWithNote(models.Model): + played = models.ForeignKey(PlayedWith, models.CASCADE) + note = 
models.TextField() + + +class Contact(models.Model): + label = models.CharField(max_length=100) + + +class Email(Contact): + email_address = models.EmailField(max_length=100) + + +class Researcher(models.Model): + contacts = models.ManyToManyField(Contact, related_name="research_contacts") + + +class Food(models.Model): + name = models.CharField(max_length=20, unique=True) + + +class Eaten(models.Model): + food = models.ForeignKey(Food, models.CASCADE, to_field="name") + meal = models.CharField(max_length=20) + + +# Models for #15776 + + +class Policy(models.Model): + policy_number = models.CharField(max_length=10) + + +class Version(models.Model): + policy = models.ForeignKey(Policy, models.CASCADE) + + +class Location(models.Model): + version = models.ForeignKey(Version, models.SET_NULL, blank=True, null=True) + + +class Item(models.Model): + version = models.ForeignKey(Version, models.CASCADE) + location = models.ForeignKey(Location, models.SET_NULL, blank=True, null=True) + +# Models for #16128 + + +class File(models.Model): + pass + + +class Image(File): + class Meta: + proxy = True + + +class Photo(Image): + class Meta: + proxy = True + + +class FooImage(models.Model): + my_image = models.ForeignKey(Image, models.CASCADE) + + +class FooFile(models.Model): + my_file = models.ForeignKey(File, models.CASCADE) + + +class FooPhoto(models.Model): + my_photo = models.ForeignKey(Photo, models.CASCADE) + + +class FooFileProxy(FooFile): + class Meta: + proxy = True + + +class OrgUnit(models.Model): + name = models.CharField(max_length=64, unique=True) + + +class Login(models.Model): + description = models.CharField(max_length=32) + orgunit = models.ForeignKey(OrgUnit, models.CASCADE) + + +class House(models.Model): + address = models.CharField(max_length=32) + + +class OrderedPerson(models.Model): + name = models.CharField(max_length=32) + lives_in = models.ForeignKey(House, models.CASCADE) + + class Meta: + ordering = ['name'] diff --git 
a/tests/delete_regress/tests.py b/tests/delete_regress/tests.py new file mode 100644 index 00000000..21287337 --- /dev/null +++ b/tests/delete_regress/tests.py @@ -0,0 +1,347 @@ +from __future__ import unicode_literals + +import datetime + +from django.db import connection, models, transaction +from django.test import TestCase, TransactionTestCase, skipUnlessDBFeature + +from .models import ( + Award, AwardNote, Book, Child, Eaten, Email, File, Food, FooFile, + FooFileProxy, FooImage, FooPhoto, House, Image, Item, Location, Login, + OrderedPerson, OrgUnit, Person, Photo, PlayedWith, PlayedWithNote, Policy, + Researcher, Toy, Version, +) + + +# Can't run this test under SQLite, because you can't +# get two connections to an in-memory database. +@skipUnlessDBFeature('test_db_allows_multiple_connections') +class DeleteLockingTest(TransactionTestCase): + + available_apps = ['delete_regress'] + + def setUp(self): + # Create a second connection to the default database + self.conn2 = connection.copy() + self.conn2.set_autocommit(False) + + def tearDown(self): + # Close down the second connection. + self.conn2.rollback() + self.conn2.close() + + def test_concurrent_delete(self): + """Concurrent deletes don't collide and lock the database (#9479).""" + with transaction.atomic(): + Book.objects.create(id=1, pagecount=100) + Book.objects.create(id=2, pagecount=200) + Book.objects.create(id=3, pagecount=300) + + with transaction.atomic(): + # Start a transaction on the main connection. + self.assertEqual(3, Book.objects.count()) + + # Delete something using another database connection. + with self.conn2.cursor() as cursor2: + cursor2.execute("DELETE from delete_regress_book WHERE id = 1") + self.conn2.commit() + + # In the same transaction on the main connection, perform a + # queryset delete that covers the object deleted with the other + # connection. This causes an infinite loop under MySQL InnoDB + # unless we keep track of already deleted objects. 
+ Book.objects.filter(pagecount__lt=250).delete() + + self.assertEqual(1, Book.objects.count()) + + +class DeleteCascadeTests(TestCase): + def test_generic_relation_cascade(self): + """ + Django cascades deletes through generic-related objects to their + reverse relations. + """ + person = Person.objects.create(name='Nelson Mandela') + award = Award.objects.create(name='Nobel', content_object=person) + AwardNote.objects.create(note='a peace prize', + award=award) + self.assertEqual(AwardNote.objects.count(), 1) + person.delete() + self.assertEqual(Award.objects.count(), 0) + # first two asserts are just sanity checks, this is the kicker: + self.assertEqual(AwardNote.objects.count(), 0) + + def test_fk_to_m2m_through(self): + """ + If an M2M relationship has an explicitly-specified through model, and + some other model has an FK to that through model, deletion is cascaded + from one of the participants in the M2M, to the through model, to its + related model. + """ + juan = Child.objects.create(name='Juan') + paints = Toy.objects.create(name='Paints') + played = PlayedWith.objects.create(child=juan, toy=paints, + date=datetime.date.today()) + PlayedWithNote.objects.create(played=played, + note='the next Jackson Pollock') + self.assertEqual(PlayedWithNote.objects.count(), 1) + paints.delete() + self.assertEqual(PlayedWith.objects.count(), 0) + # first two asserts just sanity checks, this is the kicker: + self.assertEqual(PlayedWithNote.objects.count(), 0) + + def test_15776(self): + policy = Policy.objects.create(pk=1, policy_number="1234") + version = Version.objects.create(policy=policy) + location = Location.objects.create(version=version) + Item.objects.create(version=version, location=location) + policy.delete() + + +class DeleteCascadeTransactionTests(TransactionTestCase): + + available_apps = ['delete_regress'] + + def test_inheritance(self): + """ + Auto-created many-to-many through tables referencing a parent model are + correctly found by the delete cascade 
when a child of that parent is + deleted. + + Refs #14896. + """ + r = Researcher.objects.create() + email = Email.objects.create( + label="office-email", email_address="carl@science.edu" + ) + r.contacts.add(email) + + email.delete() + + def test_to_field(self): + """ + Cascade deletion works with ForeignKey.to_field set to non-PK. + """ + apple = Food.objects.create(name="apple") + Eaten.objects.create(food=apple, meal="lunch") + + apple.delete() + self.assertFalse(Food.objects.exists()) + self.assertFalse(Eaten.objects.exists()) + + +class LargeDeleteTests(TestCase): + def test_large_deletes(self): + "Regression for #13309 -- if the number of objects > chunk size, deletion still occurs" + for x in range(300): + Book.objects.create(pagecount=x + 100) + # attach a signal to make sure we will not fast-delete + + def noop(*args, **kwargs): + pass + models.signals.post_delete.connect(noop, sender=Book) + Book.objects.all().delete() + models.signals.post_delete.disconnect(noop, sender=Book) + self.assertEqual(Book.objects.count(), 0) + + +class ProxyDeleteTest(TestCase): + """ + Tests on_delete behavior for proxy models. + + See #16128. + """ + def create_image(self): + """Return an Image referenced by both a FooImage and a FooFile.""" + # Create an Image + test_image = Image() + test_image.save() + foo_image = FooImage(my_image=test_image) + foo_image.save() + + # Get the Image instance as a File + test_file = File.objects.get(pk=test_image.pk) + foo_file = FooFile(my_file=test_file) + foo_file.save() + + return test_image + + def test_delete_proxy(self): + """ + Deleting the *proxy* instance bubbles through to its non-proxy and + *all* referring objects are deleted. + """ + self.create_image() + + Image.objects.all().delete() + + # An Image deletion == File deletion + self.assertEqual(len(Image.objects.all()), 0) + self.assertEqual(len(File.objects.all()), 0) + + # The Image deletion cascaded and *all* references to it are deleted. 
+ self.assertEqual(len(FooImage.objects.all()), 0) + self.assertEqual(len(FooFile.objects.all()), 0) + + def test_delete_proxy_of_proxy(self): + """ + Deleting a proxy-of-proxy instance should bubble through to its proxy + and non-proxy parents, deleting *all* referring objects. + """ + test_image = self.create_image() + + # Get the Image as a Photo + test_photo = Photo.objects.get(pk=test_image.pk) + foo_photo = FooPhoto(my_photo=test_photo) + foo_photo.save() + + Photo.objects.all().delete() + + # A Photo deletion == Image deletion == File deletion + self.assertEqual(len(Photo.objects.all()), 0) + self.assertEqual(len(Image.objects.all()), 0) + self.assertEqual(len(File.objects.all()), 0) + + # The Photo deletion should have cascaded and deleted *all* + # references to it. + self.assertEqual(len(FooPhoto.objects.all()), 0) + self.assertEqual(len(FooFile.objects.all()), 0) + self.assertEqual(len(FooImage.objects.all()), 0) + + def test_delete_concrete_parent(self): + """ + Deleting an instance of a concrete model should also delete objects + referencing its proxy subclass. + """ + self.create_image() + + File.objects.all().delete() + + # A File deletion == Image deletion + self.assertEqual(len(File.objects.all()), 0) + self.assertEqual(len(Image.objects.all()), 0) + + # The File deletion should have cascaded and deleted *all* references + # to it. + self.assertEqual(len(FooFile.objects.all()), 0) + self.assertEqual(len(FooImage.objects.all()), 0) + + def test_delete_proxy_pair(self): + """ + If a pair of proxy models are linked by an FK from one concrete parent + to the other, deleting one proxy model cascade-deletes the other, and + the deletion happens in the right order (not triggering an + IntegrityError on databases unable to defer integrity checks). + + Refs #17918. 
+ """ + # Create an Image (proxy of File) and FooFileProxy (proxy of FooFile, + # which has an FK to File) + image = Image.objects.create() + as_file = File.objects.get(pk=image.pk) + FooFileProxy.objects.create(my_file=as_file) + + Image.objects.all().delete() + + self.assertEqual(len(FooFileProxy.objects.all()), 0) + + def test_19187_values(self): + with self.assertRaises(TypeError): + Image.objects.values().delete() + with self.assertRaises(TypeError): + Image.objects.values_list().delete() + + +class Ticket19102Tests(TestCase): + """ + Test different queries which alter the SELECT clause of the query. We + also must be using a subquery for the deletion (that is, the original + query has a join in it). The deletion should be done as "fast-path" + deletion (that is, just one query for the .delete() call). + + Note that .values() is not tested here on purpose. .values().delete() + doesn't work for non fast-path deletes at all. + """ + def setUp(self): + self.o1 = OrgUnit.objects.create(name='o1') + self.o2 = OrgUnit.objects.create(name='o2') + self.l1 = Login.objects.create(description='l1', orgunit=self.o1) + self.l2 = Login.objects.create(description='l2', orgunit=self.o2) + + @skipUnlessDBFeature("update_can_self_select") + def test_ticket_19102_annotate(self): + with self.assertNumQueries(1): + Login.objects.order_by('description').filter( + orgunit__name__isnull=False + ).annotate( + n=models.Count('description') + ).filter( + n=1, pk=self.l1.pk + ).delete() + self.assertFalse(Login.objects.filter(pk=self.l1.pk).exists()) + self.assertTrue(Login.objects.filter(pk=self.l2.pk).exists()) + + @skipUnlessDBFeature("update_can_self_select") + def test_ticket_19102_extra(self): + with self.assertNumQueries(1): + Login.objects.order_by('description').filter( + orgunit__name__isnull=False + ).extra( + select={'extraf': '1'} + ).filter( + pk=self.l1.pk + ).delete() + self.assertFalse(Login.objects.filter(pk=self.l1.pk).exists()) + 
self.assertTrue(Login.objects.filter(pk=self.l2.pk).exists()) + + @skipUnlessDBFeature("update_can_self_select") + @skipUnlessDBFeature('can_distinct_on_fields') + def test_ticket_19102_distinct_on(self): + # Both Login objs should have same description so that only the one + # having smaller PK will be deleted. + Login.objects.update(description='description') + with self.assertNumQueries(1): + Login.objects.distinct('description').order_by('pk').filter( + orgunit__name__isnull=False + ).delete() + # Assumed that l1 which is created first has smaller PK. + self.assertFalse(Login.objects.filter(pk=self.l1.pk).exists()) + self.assertTrue(Login.objects.filter(pk=self.l2.pk).exists()) + + @skipUnlessDBFeature("update_can_self_select") + def test_ticket_19102_select_related(self): + with self.assertNumQueries(1): + Login.objects.filter( + pk=self.l1.pk + ).filter( + orgunit__name__isnull=False + ).order_by( + 'description' + ).select_related('orgunit').delete() + self.assertFalse(Login.objects.filter(pk=self.l1.pk).exists()) + self.assertTrue(Login.objects.filter(pk=self.l2.pk).exists()) + + @skipUnlessDBFeature("update_can_self_select") + def test_ticket_19102_defer(self): + with self.assertNumQueries(1): + Login.objects.filter( + pk=self.l1.pk + ).filter( + orgunit__name__isnull=False + ).order_by( + 'description' + ).only('id').delete() + self.assertFalse(Login.objects.filter(pk=self.l1.pk).exists()) + self.assertTrue(Login.objects.filter(pk=self.l2.pk).exists()) + + +class OrderedDeleteTests(TestCase): + def test_meta_ordered_delete(self): + # When a subquery is performed by deletion code, the subquery must be + # cleared of all ordering. There was a but that caused _meta ordering + # to be used. Refs #19720. 
+ h = House.objects.create(address='Foo') + OrderedPerson.objects.create(name='Jack', lives_in=h) + OrderedPerson.objects.create(name='Bob', lives_in=h) + OrderedPerson.objects.filter(lives_in__address='Foo').delete() + self.assertEqual(OrderedPerson.objects.count(), 0) diff --git a/tests/distinct_on_fields/__init__.py b/tests/distinct_on_fields/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/distinct_on_fields/models.py b/tests/distinct_on_fields/models.py new file mode 100644 index 00000000..2c33f3ad --- /dev/null +++ b/tests/distinct_on_fields/models.py @@ -0,0 +1,61 @@ +from __future__ import unicode_literals + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Tag(models.Model): + name = models.CharField(max_length=10) + parent = models.ForeignKey( + 'self', + models.SET_NULL, + blank=True, + null=True, + related_name='children', + ) + + class Meta: + ordering = ['name'] + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Celebrity(models.Model): + name = models.CharField("Name", max_length=20) + greatest_fan = models.ForeignKey( + "Fan", + models.SET_NULL, + null=True, + unique=True, + ) + + def __str__(self): + return self.name + + +class Fan(models.Model): + fan_of = models.ForeignKey(Celebrity, models.CASCADE) + + +@python_2_unicode_compatible +class Staff(models.Model): + id = models.IntegerField(primary_key=True) + name = models.CharField(max_length=50) + organisation = models.CharField(max_length=100) + tags = models.ManyToManyField(Tag, through='StaffTag') + coworkers = models.ManyToManyField('self') + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class StaffTag(models.Model): + staff = models.ForeignKey(Staff, models.CASCADE) + tag = models.ForeignKey(Tag, models.CASCADE) + + def __str__(self): + return "%s -> %s" % (self.tag, self.staff) diff --git 
a/tests/distinct_on_fields/tests.py b/tests/distinct_on_fields/tests.py new file mode 100644 index 00000000..e7445003 --- /dev/null +++ b/tests/distinct_on_fields/tests.py @@ -0,0 +1,130 @@ +from __future__ import unicode_literals + +from django.db.models import Max +from django.test import TestCase, skipUnlessDBFeature +from django.test.utils import str_prefix + +from .models import Celebrity, Fan, Staff, StaffTag, Tag + + +@skipUnlessDBFeature('can_distinct_on_fields') +@skipUnlessDBFeature('supports_nullable_unique_constraints') +class DistinctOnTests(TestCase): + def setUp(self): + t1 = Tag.objects.create(name='t1') + Tag.objects.create(name='t2', parent=t1) + t3 = Tag.objects.create(name='t3', parent=t1) + Tag.objects.create(name='t4', parent=t3) + Tag.objects.create(name='t5', parent=t3) + + self.p1_o1 = Staff.objects.create(id=1, name="p1", organisation="o1") + self.p2_o1 = Staff.objects.create(id=2, name="p2", organisation="o1") + self.p3_o1 = Staff.objects.create(id=3, name="p3", organisation="o1") + self.p1_o2 = Staff.objects.create(id=4, name="p1", organisation="o2") + self.p1_o1.coworkers.add(self.p2_o1, self.p3_o1) + StaffTag.objects.create(staff=self.p1_o1, tag=t1) + StaffTag.objects.create(staff=self.p1_o1, tag=t1) + + celeb1 = Celebrity.objects.create(name="c1") + celeb2 = Celebrity.objects.create(name="c2") + + self.fan1 = Fan.objects.create(fan_of=celeb1) + self.fan2 = Fan.objects.create(fan_of=celeb1) + self.fan3 = Fan.objects.create(fan_of=celeb2) + + def test_basic_distinct_on(self): + """QuerySet.distinct('field', ...) 
works""" + # (qset, expected) tuples + qsets = ( + ( + Staff.objects.distinct().order_by('name'), + ['', '', '', ''], + ), + ( + Staff.objects.distinct('name').order_by('name'), + ['', '', ''], + ), + ( + Staff.objects.distinct('organisation').order_by('organisation', 'name'), + ['', ''], + ), + ( + Staff.objects.distinct('name', 'organisation').order_by('name', 'organisation'), + ['', '', '', ''], + ), + ( + Celebrity.objects.filter(fan__in=[self.fan1, self.fan2, self.fan3]).distinct('name').order_by('name'), + ['', ''], + ), + # Does combining querysets work? + ( + (Celebrity.objects.filter(fan__in=[self.fan1, self.fan2]). + distinct('name').order_by('name') | + Celebrity.objects.filter(fan__in=[self.fan3]). + distinct('name').order_by('name')), + ['', ''], + ), + ( + StaffTag.objects.distinct('staff', 'tag'), + [' p1>'], + ), + ( + Tag.objects.order_by('parent__pk', 'pk').distinct('parent'), + ['', '', ''], + ), + ( + StaffTag.objects.select_related('staff').distinct('staff__name').order_by('staff__name'), + [' p1>'], + ), + # Fetch the alphabetically first coworker for each worker + ( + (Staff.objects.distinct('id').order_by('id', 'coworkers__name'). + values_list('id', 'coworkers__name')), + [str_prefix("(1, %(_)s'p2')"), str_prefix("(2, %(_)s'p1')"), + str_prefix("(3, %(_)s'p1')"), "(4, None)"] + ), + ) + for qset, expected in qsets: + self.assertQuerysetEqual(qset, expected) + self.assertEqual(qset.count(), len(expected)) + + # Combining queries with different distinct_fields is not allowed. 
+ base_qs = Celebrity.objects.all() + with self.assertRaisesMessage(AssertionError, "Cannot combine queries with different distinct fields."): + base_qs.distinct('id') & base_qs.distinct('name') + + # Test join unreffing + c1 = Celebrity.objects.distinct('greatest_fan__id', 'greatest_fan__fan_of') + self.assertIn('OUTER JOIN', str(c1.query)) + c2 = c1.distinct('pk') + self.assertNotIn('OUTER JOIN', str(c2.query)) + + def test_distinct_not_implemented_checks(self): + # distinct + annotate not allowed + with self.assertRaises(NotImplementedError): + Celebrity.objects.annotate(Max('id')).distinct('id')[0] + with self.assertRaises(NotImplementedError): + Celebrity.objects.distinct('id').annotate(Max('id'))[0] + + # However this check is done only when the query executes, so you + # can use distinct() to remove the fields before execution. + Celebrity.objects.distinct('id').annotate(Max('id')).distinct()[0] + # distinct + aggregate not allowed + with self.assertRaises(NotImplementedError): + Celebrity.objects.distinct('id').aggregate(Max('id')) + + def test_distinct_on_in_ordered_subquery(self): + qs = Staff.objects.distinct('name').order_by('name', 'id') + qs = Staff.objects.filter(pk__in=qs).order_by('name') + self.assertSequenceEqual(qs, [self.p1_o1, self.p2_o1, self.p3_o1]) + qs = Staff.objects.distinct('name').order_by('name', '-id') + qs = Staff.objects.filter(pk__in=qs).order_by('name') + self.assertSequenceEqual(qs, [self.p1_o2, self.p2_o1, self.p3_o1]) + + def test_distinct_on_get_ordering_preserved(self): + """ + Ordering shouldn't be cleared when distinct on fields are specified. 
+ refs #25081 + """ + staff = Staff.objects.distinct('name').order_by('name', '-organisation').get(name='p1') + self.assertEqual(staff.organisation, 'o2') diff --git a/tests/expressions/__init__.py b/tests/expressions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/expressions/models.py b/tests/expressions/models.py new file mode 100644 index 00000000..42e4a37b --- /dev/null +++ b/tests/expressions/models.py @@ -0,0 +1,95 @@ +""" +Tests for F() query expression syntax. +""" +import uuid + +from django.db import models + + +class Employee(models.Model): + firstname = models.CharField(max_length=50) + lastname = models.CharField(max_length=50) + salary = models.IntegerField(blank=True, null=True) + + def __str__(self): + return '%s %s' % (self.firstname, self.lastname) + + +class Company(models.Model): + name = models.CharField(max_length=100) + num_employees = models.PositiveIntegerField() + num_chairs = models.PositiveIntegerField() + ceo = models.ForeignKey( + Employee, + models.CASCADE, + related_name='company_ceo_set', + ) + point_of_contact = models.ForeignKey( + Employee, + models.SET_NULL, + related_name='company_point_of_contact_set', + null=True, + ) + + def __str__(self): + return self.name + + +class Number(models.Model): + integer = models.BigIntegerField(db_column='the_integer') + float = models.FloatField(null=True, db_column='the_float') + + def __str__(self): + return '%i, %.3f' % (self.integer, self.float) + + +class Experiment(models.Model): + name = models.CharField(max_length=24) + assigned = models.DateField() + completed = models.DateField() + estimated_time = models.DurationField() + start = models.DateTimeField() + end = models.DateTimeField() + + class Meta: + db_table = 'expressions_ExPeRiMeNt' + ordering = ('name',) + + def duration(self): + return self.end - self.start + + +class Result(models.Model): + experiment = models.ForeignKey(Experiment, models.CASCADE) + result_time = models.DateTimeField() + + def 
__str__(self): + return "Result at %s" % self.result_time + + +class Time(models.Model): + time = models.TimeField(null=True) + + def __str__(self): + return "%s" % self.time + + +class SimulationRun(models.Model): + start = models.ForeignKey(Time, models.CASCADE, null=True, related_name='+') + end = models.ForeignKey(Time, models.CASCADE, null=True, related_name='+') + midpoint = models.TimeField() + + def __str__(self): + return "%s (%s to %s)" % (self.midpoint, self.start, self.end) + + +class UUIDPK(models.Model): + id = models.UUIDField(primary_key=True, default=uuid.uuid4) + + +class UUID(models.Model): + uuid = models.UUIDField(null=True) + uuid_fk = models.ForeignKey(UUIDPK, models.CASCADE, null=True) + + def __str__(self): + return "%s" % self.uuid diff --git a/tests/expressions/test_queryset_values.py b/tests/expressions/test_queryset_values.py new file mode 100644 index 00000000..e2645979 --- /dev/null +++ b/tests/expressions/test_queryset_values.py @@ -0,0 +1,62 @@ +from django.db.models.aggregates import Sum +from django.db.models.expressions import F +from django.test import TestCase + +from .models import Company, Employee + + +class ValuesExpressionsTests(TestCase): + @classmethod + def setUpTestData(cls): + Company.objects.create( + name='Example Inc.', num_employees=2300, num_chairs=5, + ceo=Employee.objects.create(firstname='Joe', lastname='Smith', salary=10) + ) + Company.objects.create( + name='Foobar Ltd.', num_employees=3, num_chairs=4, + ceo=Employee.objects.create(firstname='Frank', lastname='Meyer', salary=20) + ) + Company.objects.create( + name='Test GmbH', num_employees=32, num_chairs=1, + ceo=Employee.objects.create(firstname='Max', lastname='Mustermann', salary=30) + ) + + def test_values_expression(self): + self.assertSequenceEqual( + Company.objects.values(salary=F('ceo__salary')), + [{'salary': 10}, {'salary': 20}, {'salary': 30}], + ) + + def test_values_expression_group_by(self): + # values() applies annotate() first, so values 
selected are grouped by + # id, not firstname. + Employee.objects.create(firstname='Joe', lastname='Jones', salary=2) + joes = Employee.objects.filter(firstname='Joe') + self.assertSequenceEqual( + joes.values('firstname', sum_salary=Sum('salary')).order_by('sum_salary'), + [{'firstname': 'Joe', 'sum_salary': 2}, {'firstname': 'Joe', 'sum_salary': 10}], + ) + self.assertSequenceEqual( + joes.values('firstname').annotate(sum_salary=Sum('salary')), + [{'firstname': 'Joe', 'sum_salary': 12}] + ) + + def test_chained_values_with_expression(self): + Employee.objects.create(firstname='Joe', lastname='Jones', salary=2) + joes = Employee.objects.filter(firstname='Joe').values('firstname') + self.assertSequenceEqual( + joes.values('firstname', sum_salary=Sum('salary')), + [{'firstname': 'Joe', 'sum_salary': 12}] + ) + self.assertSequenceEqual( + joes.values(sum_salary=Sum('salary')), + [{'sum_salary': 12}] + ) + + def test_values_list_expression(self): + companies = Company.objects.values_list('name', F('ceo__salary')) + self.assertSequenceEqual(companies, [('Example Inc.', 10), ('Foobar Ltd.', 20), ('Test GmbH', 30)]) + + def test_values_list_expression_flat(self): + companies = Company.objects.values_list(F('ceo__salary'), flat=True) + self.assertSequenceEqual(companies, (10, 20, 30)) diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py new file mode 100644 index 00000000..4789a244 --- /dev/null +++ b/tests/expressions/tests.py @@ -0,0 +1,1519 @@ +import datetime +import pickle +import unittest +import uuid +from copy import deepcopy + +from django.core.exceptions import FieldError +from django.db import DatabaseError, connection, models, transaction +from django.db.models import CharField, Q, TimeField, UUIDField +from django.db.models.aggregates import ( + Avg, Count, Max, Min, StdDev, Sum, Variance, +) +from django.db.models.expressions import ( + Case, Col, Combinable, Exists, ExpressionList, ExpressionWrapper, F, Func, + OrderBy, OuterRef, Random, 
RawSQL, Ref, Subquery, Value, When, +) +from django.db.models.functions import ( + Coalesce, Concat, Length, Lower, Substr, Upper, +) +from django.db.models.sql import constants +from django.db.models.sql.datastructures import Join +from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature +from django.test.utils import Approximate + +from .models import ( + UUID, UUIDPK, Company, Employee, Experiment, Number, Result, SimulationRun, + Time, +) + + +class BasicExpressionsTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.example_inc = Company.objects.create( + name="Example Inc.", num_employees=2300, num_chairs=5, + ceo=Employee.objects.create(firstname="Joe", lastname="Smith", salary=10) + ) + cls.foobar_ltd = Company.objects.create( + name="Foobar Ltd.", num_employees=3, num_chairs=4, + ceo=Employee.objects.create(firstname="Frank", lastname="Meyer", salary=20) + ) + cls.max = Employee.objects.create(firstname='Max', lastname='Mustermann', salary=30) + cls.gmbh = Company.objects.create(name='Test GmbH', num_employees=32, num_chairs=1, ceo=cls.max) + + def setUp(self): + self.company_query = Company.objects.values( + "name", "num_employees", "num_chairs" + ).order_by( + "name", "num_employees", "num_chairs" + ) + + def test_annotate_values_aggregate(self): + companies = Company.objects.annotate( + salaries=F('ceo__salary'), + ).values('num_employees', 'salaries').aggregate( + result=Sum( + F('salaries') + F('num_employees'), + output_field=models.IntegerField() + ), + ) + self.assertEqual(companies['result'], 2395) + + def test_annotate_values_filter(self): + companies = Company.objects.annotate( + foo=RawSQL('%s', ['value']), + ).filter(foo='value').order_by('name') + self.assertQuerysetEqual( + companies, [ + '', + '', + '', + ], + ) + + #@unittest.skipIf(connection.vendor == 'oracle', "Oracle doesn't support using boolean type in SELECT") + #def test_filtering_on_annotate_that_uses_q(self): + # self.assertEqual( + # 
Company.objects.annotate( + # num_employees_check=ExpressionWrapper(Q(num_employees__gt=3), output_field=models.BooleanField()) + # ).filter(num_employees_check=True).count(), + # 2, + # ) + + def test_filter_inter_attribute(self): + # We can filter on attribute relationships on same model obj, e.g. + # find companies where the number of employees is greater + # than the number of chairs. + self.assertSequenceEqual( + self.company_query.filter(num_employees__gt=F("num_chairs")), [ + { + "num_chairs": 5, + "name": "Example Inc.", + "num_employees": 2300, + }, + { + "num_chairs": 1, + "name": "Test GmbH", + "num_employees": 32 + }, + ], + ) + + def test_update(self): + # We can set one field to have the value of another field + # Make sure we have enough chairs + self.company_query.update(num_chairs=F("num_employees")) + self.assertSequenceEqual( + self.company_query, [ + { + "num_chairs": 2300, + "name": "Example Inc.", + "num_employees": 2300 + }, + { + "num_chairs": 3, + "name": "Foobar Ltd.", + "num_employees": 3 + }, + { + "num_chairs": 32, + "name": "Test GmbH", + "num_employees": 32 + } + ], + ) + + def test_arithmetic(self): + # We can perform arithmetic operations in expressions + # Make sure we have 2 spare chairs + self.company_query.update(num_chairs=F("num_employees") + 2) + self.assertSequenceEqual( + self.company_query, [ + { + 'num_chairs': 2302, + 'name': 'Example Inc.', + 'num_employees': 2300 + }, + { + 'num_chairs': 5, + 'name': 'Foobar Ltd.', + 'num_employees': 3 + }, + { + 'num_chairs': 34, + 'name': 'Test GmbH', + 'num_employees': 32 + } + ], + ) + + def test_order_of_operations(self): + # Law of order of operations is followed + self. 
company_query.update( + num_chairs=F('num_employees') + 2 * F('num_employees') + ) + self.assertSequenceEqual( + self.company_query, [ + { + 'num_chairs': 6900, + 'name': 'Example Inc.', + 'num_employees': 2300 + }, + { + 'num_chairs': 9, + 'name': 'Foobar Ltd.', + 'num_employees': 3 + }, + { + 'num_chairs': 96, + 'name': 'Test GmbH', + 'num_employees': 32 + } + ], + ) + + def test_parenthesis_priority(self): + # Law of order of operations can be overridden by parentheses + self.company_query.update( + num_chairs=((F('num_employees') + 2) * F('num_employees')) + ) + self.assertSequenceEqual( + self.company_query, [ + { + 'num_chairs': 5294600, + 'name': 'Example Inc.', + 'num_employees': 2300 + }, + { + 'num_chairs': 15, + 'name': 'Foobar Ltd.', + 'num_employees': 3 + }, + { + 'num_chairs': 1088, + 'name': 'Test GmbH', + 'num_employees': 32 + } + ], + ) + + def test_update_with_fk(self): + # ForeignKey can become updated with the value of another ForeignKey. + self.assertEqual( + Company.objects.update(point_of_contact=F('ceo')), + 3 + ) + self.assertQuerysetEqual( + Company.objects.all(), [ + "Joe Smith", + "Frank Meyer", + "Max Mustermann", + ], + lambda c: str(c.point_of_contact), + ordered=False + ) + + def test_update_with_none(self): + Number.objects.create(integer=1, float=1.0) + Number.objects.create(integer=2) + Number.objects.filter(float__isnull=False).update(float=Value(None)) + self.assertQuerysetEqual( + Number.objects.all(), [ + None, + None, + ], + lambda n: n.float, + ordered=False + ) + + def test_filter_with_join(self): + # F Expressions can also span joins + Company.objects.update(point_of_contact=F('ceo')) + c = Company.objects.all()[0] + c.point_of_contact = Employee.objects.create(firstname="Guido", lastname="van Rossum") + c.save() + + self.assertQuerysetEqual( + Company.objects.filter(ceo__firstname=F("point_of_contact__firstname")), [ + "Foobar Ltd.", + "Test GmbH", + ], + lambda c: c.name, + ordered=False + ) + + Company.objects.exclude( 
+ ceo__firstname=F("point_of_contact__firstname") + ).update(name="foo") + self.assertEqual( + Company.objects.exclude( + ceo__firstname=F('point_of_contact__firstname') + ).get().name, + "foo", + ) + + with transaction.atomic(): + msg = "Joined field references are not permitted in this query" + with self.assertRaisesMessage(FieldError, msg): + Company.objects.exclude( + ceo__firstname=F('point_of_contact__firstname') + ).update(name=F('point_of_contact__lastname')) + + def test_object_update(self): + # F expressions can be used to update attributes on single objects + test_gmbh = Company.objects.get(name="Test GmbH") + self.assertEqual(test_gmbh.num_employees, 32) + test_gmbh.num_employees = F("num_employees") + 4 + test_gmbh.save() + test_gmbh = Company.objects.get(pk=test_gmbh.pk) + self.assertEqual(test_gmbh.num_employees, 36) + + def test_new_object_save(self): + # We should be able to use Funcs when inserting new data + test_co = Company( + name=Lower(Value("UPPER")), num_employees=32, num_chairs=1, + ceo=Employee.objects.create(firstname="Just", lastname="Doit", salary=30), + ) + test_co.save() + test_co.refresh_from_db() + self.assertEqual(test_co.name, "upper") + + def test_new_object_create(self): + test_co = Company.objects.create( + name=Lower(Value("UPPER")), num_employees=32, num_chairs=1, + ceo=Employee.objects.create(firstname="Just", lastname="Doit", salary=30), + ) + test_co.refresh_from_db() + self.assertEqual(test_co.name, "upper") + + def test_object_create_with_aggregate(self): + # Aggregates are not allowed when inserting new data + with self.assertRaisesMessage(FieldError, 'Aggregate functions are not allowed in this query'): + Company.objects.create( + name='Company', num_employees=Max(Value(1)), num_chairs=1, + ceo=Employee.objects.create(firstname="Just", lastname="Doit", salary=30), + ) + + def test_object_update_fk(self): + # F expressions cannot be used to update attributes which are foreign + # keys, or attributes which involve 
joins. + test_gmbh = Company.objects.get(name="Test GmbH") + + def test(): + test_gmbh.point_of_contact = F("ceo") + msg = 'F(ceo)": "Company.point_of_contact" must be a "Employee" instance.' + with self.assertRaisesMessage(ValueError, msg): + test() + + test_gmbh.point_of_contact = test_gmbh.ceo + test_gmbh.save() + test_gmbh.name = F("ceo__last_name") + msg = 'Joined field references are not permitted in this query' + with self.assertRaisesMessage(FieldError, msg): + test_gmbh.save() + + def test_object_update_unsaved_objects(self): + # F expressions cannot be used to update attributes on objects which do + # not yet exist in the database + test_gmbh = Company.objects.get(name="Test GmbH") + acme = Company( + name="The Acme Widget Co.", num_employees=12, num_chairs=5, + ceo=test_gmbh.ceo + ) + acme.num_employees = F("num_employees") + 16 + msg = ( + 'Failed to insert expression "Col(expressions_company, ' + 'expressions.Company.num_employees) + Value(16)" on ' + 'expressions.Company.num_employees. F() expressions can only be ' + 'used to update, not to insert.' + ) + with self.assertRaisesMessage(ValueError, msg): + acme.save() + + acme.num_employees = 12 + acme.name = Lower(F('name')) + msg = ( + 'Failed to insert expression "Lower(Col(expressions_company, ' + 'expressions.Company.name))" on expressions.Company.name. F() ' + 'expressions can only be used to update, not to insert.' 
+ ) + with self.assertRaisesMessage(ValueError, msg): + acme.save() + + def test_ticket_11722_iexact_lookup(self): + Employee.objects.create(firstname="John", lastname="Doe") + Employee.objects.create(firstname="Test", lastname="test") + + queryset = Employee.objects.filter(firstname__iexact=F('lastname')) + self.assertQuerysetEqual(queryset, [""]) + + def test_ticket_16731_startswith_lookup(self): + Employee.objects.create(firstname="John", lastname="Doe") + e2 = Employee.objects.create(firstname="Jack", lastname="Jackson") + e3 = Employee.objects.create(firstname="Jack", lastname="jackson") + self.assertSequenceEqual( + Employee.objects.filter(lastname__startswith=F('firstname')), + [e2, e3] if connection.features.has_case_insensitive_like else [e2] + ) + qs = Employee.objects.filter(lastname__istartswith=F('firstname')).order_by('pk') + self.assertSequenceEqual(qs, [e2, e3]) + + def test_ticket_18375_join_reuse(self): + # Reverse multijoin F() references and the lookup target the same join. + # Pre #18375 the F() join was generated first and the lookup couldn't + # reuse that join. + qs = Employee.objects.filter( + company_ceo_set__num_chairs=F('company_ceo_set__num_employees')) + self.assertEqual(str(qs.query).count('JOIN'), 1) + + def test_ticket_18375_kwarg_ordering(self): + # The next query was dict-randomization dependent - if the "gte=1" + # was seen first, then the F() will reuse the join generated by the + # gte lookup, if F() was seen first, then it generated a join the + # other lookups could not reuse. + qs = Employee.objects.filter( + company_ceo_set__num_chairs=F('company_ceo_set__num_employees'), + company_ceo_set__num_chairs__gte=1, + ) + self.assertEqual(str(qs.query).count('JOIN'), 1) + + def test_ticket_18375_kwarg_ordering_2(self): + # Another similar case for F() than above. Now we have the same join + # in two filter kwargs, one in the lhs lookup, one in F. 
Here pre + # #18375 the amount of joins generated was random if dict + # randomization was enabled, that is the generated query dependent + # on which clause was seen first. + qs = Employee.objects.filter( + company_ceo_set__num_employees=F('pk'), + pk=F('company_ceo_set__num_employees') + ) + self.assertEqual(str(qs.query).count('JOIN'), 1) + + def test_ticket_18375_chained_filters(self): + # F() expressions do not reuse joins from previous filter. + qs = Employee.objects.filter( + company_ceo_set__num_employees=F('pk') + ).filter( + company_ceo_set__num_employees=F('company_ceo_set__num_employees') + ) + self.assertEqual(str(qs.query).count('JOIN'), 2) + + def test_order_by_exists(self): + mary = Employee.objects.create(firstname='Mary', lastname='Mustermann', salary=20) + mustermanns_by_seniority = Employee.objects.filter(lastname='Mustermann').order_by( + # Order by whether the employee is the CEO of a company + Exists(Company.objects.filter(ceo=OuterRef('pk'))).desc() + ) + self.assertSequenceEqual(mustermanns_by_seniority, [self.max, mary]) + + def test_outerref(self): + inner = Company.objects.filter(point_of_contact=OuterRef('pk')) + msg = ( + 'This queryset contains a reference to an outer query and may only ' + 'be used in a subquery.' 
+ ) + with self.assertRaisesMessage(ValueError, msg): + inner.exists() + + outer = Employee.objects.annotate(is_point_of_contact=Exists(inner)) + self.assertIs(outer.exists(), True) + + def test_exist_single_field_output_field(self): + queryset = Company.objects.values('pk') + self.assertIsInstance(Exists(queryset).output_field, models.BooleanField) + + def test_subquery(self): + Company.objects.filter(name='Example Inc.').update( + point_of_contact=Employee.objects.get(firstname='Joe', lastname='Smith'), + ceo=Employee.objects.get(firstname='Max', lastname='Mustermann'), + ) + Employee.objects.create(firstname='Bob', lastname='Brown', salary=40) + qs = Employee.objects.annotate( + is_point_of_contact=Exists(Company.objects.filter(point_of_contact=OuterRef('pk'))), + is_not_point_of_contact=~Exists(Company.objects.filter(point_of_contact=OuterRef('pk'))), + is_ceo_of_small_company=Exists(Company.objects.filter(num_employees__lt=200, ceo=OuterRef('pk'))), + is_ceo_small_2=~~Exists(Company.objects.filter(num_employees__lt=200, ceo=OuterRef('pk'))), + largest_company=Subquery(Company.objects.order_by('-num_employees').filter( + models.Q(ceo=OuterRef('pk')) | models.Q(point_of_contact=OuterRef('pk')) + ).values('name')[:1], output_field=models.CharField()) + ).values( + 'firstname', + 'is_point_of_contact', + 'is_not_point_of_contact', + 'is_ceo_of_small_company', + 'is_ceo_small_2', + 'largest_company', + ).order_by('firstname') + + results = list(qs) + # Could use Coalesce(subq, Value('')) instead except for the bug in + # cx_Oracle mentioned in #23843. 
+ bob = results[0] + if bob['largest_company'] == '' and connection.features.interprets_empty_strings_as_nulls: + bob['largest_company'] = None + + self.assertEqual(results, [ + { + 'firstname': 'Bob', + 'is_point_of_contact': False, + 'is_not_point_of_contact': True, + 'is_ceo_of_small_company': False, + 'is_ceo_small_2': False, + 'largest_company': None, + }, + { + 'firstname': 'Frank', + 'is_point_of_contact': False, + 'is_not_point_of_contact': True, + 'is_ceo_of_small_company': True, + 'is_ceo_small_2': True, + 'largest_company': 'Foobar Ltd.', + }, + { + 'firstname': 'Joe', + 'is_point_of_contact': True, + 'is_not_point_of_contact': False, + 'is_ceo_of_small_company': False, + 'is_ceo_small_2': False, + 'largest_company': 'Example Inc.', + }, + { + 'firstname': 'Max', + 'is_point_of_contact': False, + 'is_not_point_of_contact': True, + 'is_ceo_of_small_company': True, + 'is_ceo_small_2': True, + 'largest_company': 'Example Inc.' + } + ]) + # A less elegant way to write the same query: this uses a LEFT OUTER + # JOIN and an IS NULL, inside a WHERE NOT IN which is probably less + # efficient than EXISTS. + self.assertCountEqual( + qs.filter(is_point_of_contact=True).values('pk'), + Employee.objects.exclude(company_point_of_contact_set=None).values('pk') + ) + + def test_in_subquery(self): + # This is a contrived test (and you really wouldn't write this query), + # but it is a succinct way to test the __in=Subquery() construct. 
+ small_companies = Company.objects.filter(num_employees__lt=200).values('pk') + subquery_test = Company.objects.filter(pk__in=Subquery(small_companies)) + self.assertCountEqual(subquery_test, [self.foobar_ltd, self.gmbh]) + subquery_test2 = Company.objects.filter(pk=Subquery(small_companies.filter(num_employees=3))) + self.assertCountEqual(subquery_test2, [self.foobar_ltd]) + + def test_uuid_pk_subquery(self): + u = UUIDPK.objects.create() + UUID.objects.create(uuid_fk=u) + qs = UUIDPK.objects.filter(id__in=Subquery(UUID.objects.values('uuid_fk__id'))) + self.assertCountEqual(qs, [u]) + + def test_nested_subquery(self): + inner = Company.objects.filter(point_of_contact=OuterRef('pk')) + outer = Employee.objects.annotate(is_point_of_contact=Exists(inner)) + contrived = Employee.objects.annotate( + is_point_of_contact=Subquery( + outer.filter(pk=OuterRef('pk')).values('is_point_of_contact'), + output_field=models.BooleanField(), + ), + ) + self.assertCountEqual(contrived.values_list(), outer.values_list()) + + def test_nested_subquery_outer_ref_2(self): + first = Time.objects.create(time='09:00') + second = Time.objects.create(time='17:00') + third = Time.objects.create(time='21:00') + SimulationRun.objects.bulk_create([ + SimulationRun(start=first, end=second, midpoint='12:00'), + SimulationRun(start=first, end=third, midpoint='15:00'), + SimulationRun(start=second, end=first, midpoint='00:00'), + ]) + inner = Time.objects.filter(time=OuterRef(OuterRef('time')), pk=OuterRef('start')).values('time') + middle = SimulationRun.objects.annotate(other=Subquery(inner)).values('other')[:1] + outer = Time.objects.annotate(other=Subquery(middle, output_field=models.TimeField())) + # This is a contrived example. It exercises the double OuterRef form. 
+ self.assertCountEqual(outer, [first, second, third]) + + def test_nested_subquery_outer_ref_with_autofield(self): + first = Time.objects.create(time='09:00') + second = Time.objects.create(time='17:00') + SimulationRun.objects.create(start=first, end=second, midpoint='12:00') + inner = SimulationRun.objects.filter(start=OuterRef(OuterRef('pk'))).values('start') + middle = Time.objects.annotate(other=Subquery(inner)).values('other')[:1] + outer = Time.objects.annotate(other=Subquery(middle, output_field=models.IntegerField())) + # This exercises the double OuterRef form with AutoField as pk. + self.assertCountEqual(outer, [first, second]) + + def test_annotations_within_subquery(self): + Company.objects.filter(num_employees__lt=50).update(ceo=Employee.objects.get(firstname='Frank')) + inner = Company.objects.filter( + ceo=OuterRef('pk') + ).values('ceo').annotate(total_employees=models.Sum('num_employees')).values('total_employees') + outer = Employee.objects.annotate(total_employees=Subquery(inner)).filter(salary__lte=Subquery(inner)) + self.assertSequenceEqual( + outer.order_by('-total_employees').values('salary', 'total_employees'), + [{'salary': 10, 'total_employees': 2300}, {'salary': 20, 'total_employees': 35}], + ) + + def test_subquery_references_joined_table_twice(self): + inner = Company.objects.filter( + num_chairs__gte=OuterRef('ceo__salary'), + num_employees__gte=OuterRef('point_of_contact__salary'), + ) + # Another contrived example (there is no need to have a subquery here) + outer = Company.objects.filter(pk__in=Subquery(inner.values('pk'))) + self.assertFalse(outer.exists()) + + def test_explicit_output_field(self): + class FuncA(Func): + output_field = models.CharField() + + class FuncB(Func): + pass + + expr = FuncB(FuncA()) + self.assertEqual(expr.output_field, FuncA.output_field) + + def test_outerref_mixed_case_table_name(self): + inner = Result.objects.filter(result_time__gte=OuterRef('experiment__assigned')) + outer = 
Result.objects.filter(pk__in=Subquery(inner.values('pk'))) + self.assertFalse(outer.exists()) + + def test_outerref_with_operator(self): + inner = Company.objects.filter(num_employees=OuterRef('ceo__salary') + 2) + outer = Company.objects.filter(pk__in=Subquery(inner.values('pk'))) + self.assertEqual(outer.get().name, 'Test GmbH') + + def test_pickle_expression(self): + expr = Value(1, output_field=models.IntegerField()) + expr.convert_value # populate cached property + self.assertEqual(pickle.loads(pickle.dumps(expr)), expr) + + +class IterableLookupInnerExpressionsTests(TestCase): + @classmethod + def setUpTestData(cls): + ceo = Employee.objects.create(firstname='Just', lastname='Doit', salary=30) + # MySQL requires that the values calculated for expressions don't pass + # outside of the field's range, so it's inconvenient to use the values + # in the more general tests. + Company.objects.create(name='5020 Ltd', num_employees=50, num_chairs=20, ceo=ceo) + Company.objects.create(name='5040 Ltd', num_employees=50, num_chairs=40, ceo=ceo) + Company.objects.create(name='5050 Ltd', num_employees=50, num_chairs=50, ceo=ceo) + Company.objects.create(name='5060 Ltd', num_employees=50, num_chairs=60, ceo=ceo) + Company.objects.create(name='99300 Ltd', num_employees=99, num_chairs=300, ceo=ceo) + + def test_in_lookup_allows_F_expressions_and_expressions_for_integers(self): + # __in lookups can use F() expressions for integers. 
+ queryset = Company.objects.filter(num_employees__in=([F('num_chairs') - 10])) + self.assertQuerysetEqual(queryset, [''], ordered=False) + self.assertQuerysetEqual( + Company.objects.filter(num_employees__in=([F('num_chairs') - 10, F('num_chairs') + 10])), + ['', ''], + ordered=False + ) + self.assertQuerysetEqual( + Company.objects.filter( + num_employees__in=([F('num_chairs') - 10, F('num_chairs'), F('num_chairs') + 10]) + ), + ['', '', ''], + ordered=False + ) + + def test_expressions_in_lookups_join_choice(self): + self.skipTest('failing on MSSQL') + midpoint = datetime.time(13, 0) + t1 = Time.objects.create(time=datetime.time(12, 0)) + t2 = Time.objects.create(time=datetime.time(14, 0)) + SimulationRun.objects.create(start=t1, end=t2, midpoint=midpoint) + SimulationRun.objects.create(start=t1, end=None, midpoint=midpoint) + SimulationRun.objects.create(start=None, end=t2, midpoint=midpoint) + SimulationRun.objects.create(start=None, end=None, midpoint=midpoint) + + queryset = SimulationRun.objects.filter(midpoint__range=[F('start__time'), F('end__time')]) + self.assertQuerysetEqual( + queryset, + [''], + ordered=False + ) + for alias in queryset.query.alias_map.values(): + if isinstance(alias, Join): + self.assertEqual(alias.join_type, constants.INNER) + + queryset = SimulationRun.objects.exclude(midpoint__range=[F('start__time'), F('end__time')]) + self.assertQuerysetEqual(queryset, [], ordered=False) + for alias in queryset.query.alias_map.values(): + if isinstance(alias, Join): + self.assertEqual(alias.join_type, constants.LOUTER) + + def test_range_lookup_allows_F_expressions_and_expressions_for_integers(self): + # Range lookups can use F() expressions for integers. 
+ Company.objects.filter(num_employees__exact=F("num_chairs")) + self.assertQuerysetEqual( + Company.objects.filter(num_employees__range=(F('num_chairs'), 100)), + ['', '', ''], + ordered=False + ) + self.assertQuerysetEqual( + Company.objects.filter(num_employees__range=(F('num_chairs') - 10, F('num_chairs') + 10)), + ['', '', ''], + ordered=False + ) + self.assertQuerysetEqual( + Company.objects.filter(num_employees__range=(F('num_chairs') - 10, 100)), + ['', '', '', ''], + ordered=False + ) + self.assertQuerysetEqual( + Company.objects.filter(num_employees__range=(1, 100)), + [ + '', '', '', + '', '', + ], + ordered=False + ) + + @unittest.skipUnless(connection.vendor == 'sqlite', + "This defensive test only works on databases that don't validate parameter types") + def test_complex_expressions_do_not_introduce_sql_injection_via_untrusted_string_inclusion(self): + """ + This tests that SQL injection isn't possible using compilation of + expressions in iterable filters, as their compilation happens before + the main query compilation. It's limited to SQLite, as PostgreSQL, + Oracle and other vendors have defense in depth against this by type + checking. Testing against SQLite (the most permissive of the built-in + databases) demonstrates that the problem doesn't exist while keeping + the test simple. 
+ """ + queryset = Company.objects.filter(name__in=[F('num_chairs') + '1)) OR ((1==1']) + self.assertQuerysetEqual(queryset, [], ordered=False) + + def test_in_lookup_allows_F_expressions_and_expressions_for_datetimes(self): + start = datetime.datetime(2016, 2, 3, 15, 0, 0) + end = datetime.datetime(2016, 2, 5, 15, 0, 0) + experiment_1 = Experiment.objects.create( + name='Integrity testing', + assigned=start.date(), + start=start, + end=end, + completed=end.date(), + estimated_time=end - start, + ) + experiment_2 = Experiment.objects.create( + name='Taste testing', + assigned=start.date(), + start=start, + end=end, + completed=end.date(), + estimated_time=end - start, + ) + Result.objects.create( + experiment=experiment_1, + result_time=datetime.datetime(2016, 2, 4, 15, 0, 0), + ) + Result.objects.create( + experiment=experiment_1, + result_time=datetime.datetime(2016, 3, 10, 2, 0, 0), + ) + Result.objects.create( + experiment=experiment_2, + result_time=datetime.datetime(2016, 1, 8, 5, 0, 0), + ) + + within_experiment_time = [F('experiment__start'), F('experiment__end')] + queryset = Result.objects.filter(result_time__range=within_experiment_time) + self.assertQuerysetEqual(queryset, [""]) + + within_experiment_time = [F('experiment__start'), F('experiment__end')] + queryset = Result.objects.filter(result_time__range=within_experiment_time) + self.assertQuerysetEqual(queryset, [""]) + + +class FTests(SimpleTestCase): + + def test_deepcopy(self): + f = F("foo") + g = deepcopy(f) + self.assertEqual(f.name, g.name) + + def test_deconstruct(self): + f = F('name') + path, args, kwargs = f.deconstruct() + self.assertEqual(path, 'django.db.models.expressions.F') + self.assertEqual(args, (f.name,)) + self.assertEqual(kwargs, {}) + + def test_equal(self): + f = F('name') + same_f = F('name') + other_f = F('username') + self.assertEqual(f, same_f) + self.assertNotEqual(f, other_f) + + def test_hash(self): + d = {F('name'): 'Bob'} + self.assertIn(F('name'), d) + 
self.assertEqual(d[F('name')], 'Bob') + + def test_not_equal_Value(self): + f = F('name') + value = Value('name') + self.assertNotEqual(f, value) + self.assertNotEqual(value, f) + + +class ExpressionsTests(TestCase): + + def test_F_reuse(self): + f = F('id') + n = Number.objects.create(integer=-1) + c = Company.objects.create( + name="Example Inc.", num_employees=2300, num_chairs=5, + ceo=Employee.objects.create(firstname="Joe", lastname="Smith") + ) + c_qs = Company.objects.filter(id=f) + self.assertEqual(c_qs.get(), c) + # Reuse the same F-object for another queryset + n_qs = Number.objects.filter(id=f) + self.assertEqual(n_qs.get(), n) + # The original query still works correctly + self.assertEqual(c_qs.get(), c) + + def test_patterns_escape(self): + r""" + Special characters (e.g. %, _ and \) stored in database are + properly escaped when using a pattern lookup with an expression + refs #16731 + """ + Employee.objects.bulk_create([ + Employee(firstname="%Joh\\nny", lastname="%Joh\\n"), + Employee(firstname="Johnny", lastname="%John"), + Employee(firstname="Jean-Claude", lastname="Claud_"), + Employee(firstname="Jean-Claude", lastname="Claude"), + Employee(firstname="Jean-Claude", lastname="Claude%"), + Employee(firstname="Johnny", lastname="Joh\\n"), + Employee(firstname="Johnny", lastname="John"), + Employee(firstname="Johnny", lastname="_ohn"), + ]) + + self.assertQuerysetEqual( + Employee.objects.filter(firstname__contains=F('lastname')), + ["", "", ""], + ordered=False, + ) + + self.assertQuerysetEqual( + Employee.objects.filter(firstname__startswith=F('lastname')), + ["", ""], + ordered=False, + ) + + self.assertQuerysetEqual( + Employee.objects.filter(firstname__endswith=F('lastname')), + [""], + ordered=False, + ) + + def test_insensitive_patterns_escape(self): + r""" + Special characters (e.g. 
%, _ and \) stored in database are + properly escaped when using a case insensitive pattern lookup with an + expression -- refs #16731 + """ + Employee.objects.bulk_create([ + Employee(firstname="%Joh\\nny", lastname="%joh\\n"), + Employee(firstname="Johnny", lastname="%john"), + Employee(firstname="Jean-Claude", lastname="claud_"), + Employee(firstname="Jean-Claude", lastname="claude"), + Employee(firstname="Jean-Claude", lastname="claude%"), + Employee(firstname="Johnny", lastname="joh\\n"), + Employee(firstname="Johnny", lastname="john"), + Employee(firstname="Johnny", lastname="_ohn"), + ]) + + self.assertQuerysetEqual( + Employee.objects.filter(firstname__icontains=F('lastname')), + ["", "", ""], + ordered=False, + ) + + self.assertQuerysetEqual( + Employee.objects.filter(firstname__istartswith=F('lastname')), + ["", ""], + ordered=False, + ) + + self.assertQuerysetEqual( + Employee.objects.filter(firstname__iendswith=F('lastname')), + [""], + ordered=False, + ) + + +class ExpressionsNumericTests(TestCase): + + def setUp(self): + Number(integer=-1).save() + Number(integer=42).save() + Number(integer=1337).save() + self.assertEqual(Number.objects.update(float=F('integer')), 3) + + def test_fill_with_value_from_same_object(self): + """ + We can fill a value in all objects with an other value of the + same object. + """ + self.assertQuerysetEqual( + Number.objects.all(), + [ + '', + '', + '' + ], + ordered=False + ) + + def test_increment_value(self): + """ + We can increment a value of all objects in a query set. + """ + self.assertEqual( + Number.objects.filter(integer__gt=0) + .update(integer=F('integer') + 1), + 2) + + self.assertQuerysetEqual( + Number.objects.all(), + [ + '', + '', + '' + ], + ordered=False + ) + + def test_filter_not_equals_other_field(self): + """ + We can filter for objects, where a value is not equals the value + of an other field. 
+ """ + self.assertEqual( + Number.objects.filter(integer__gt=0) + .update(integer=F('integer') + 1), + 2) + self.assertQuerysetEqual( + Number.objects.exclude(float=F('integer')), + [ + '', + '' + ], + ordered=False + ) + + def test_complex_expressions(self): + """ + Complex expressions of different connection types are possible. + """ + n = Number.objects.create(integer=10, float=123.45) + self.assertEqual(Number.objects.filter(pk=n.pk).update( + float=F('integer') + F('float') * 2), 1) + + self.assertEqual(Number.objects.get(pk=n.pk).integer, 10) + self.assertEqual(Number.objects.get(pk=n.pk).float, Approximate(256.900, places=3)) + + def test_incorrect_field_expression(self): + with self.assertRaisesMessage(FieldError, "Cannot resolve keyword 'nope' into field."): + list(Employee.objects.filter(firstname=F('nope'))) + + +class ExpressionOperatorTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.n = Number.objects.create(integer=42, float=15.5) + cls.n1 = Number.objects.create(integer=-42, float=-15.5) + + def test_lefthand_addition(self): + # LH Addition of floats and integers + Number.objects.filter(pk=self.n.pk).update( + integer=F('integer') + 15, + float=F('float') + 42.7 + ) + + self.assertEqual(Number.objects.get(pk=self.n.pk).integer, 57) + self.assertEqual(Number.objects.get(pk=self.n.pk).float, Approximate(58.200, places=3)) + + def test_lefthand_subtraction(self): + # LH Subtraction of floats and integers + Number.objects.filter(pk=self.n.pk).update(integer=F('integer') - 15, float=F('float') - 42.7) + + self.assertEqual(Number.objects.get(pk=self.n.pk).integer, 27) + self.assertEqual(Number.objects.get(pk=self.n.pk).float, Approximate(-27.200, places=3)) + + def test_lefthand_multiplication(self): + # Multiplication of floats and integers + Number.objects.filter(pk=self.n.pk).update(integer=F('integer') * 15, float=F('float') * 42.7) + + self.assertEqual(Number.objects.get(pk=self.n.pk).integer, 630) + 
self.assertEqual(Number.objects.get(pk=self.n.pk).float, Approximate(661.850, places=3)) + + def test_lefthand_division(self): + # LH Division of floats and integers + Number.objects.filter(pk=self.n.pk).update(integer=F('integer') / 2, float=F('float') / 42.7) + + self.assertEqual(Number.objects.get(pk=self.n.pk).integer, 21) + self.assertEqual(Number.objects.get(pk=self.n.pk).float, Approximate(0.363, places=3)) + + def test_lefthand_modulo(self): + # LH Modulo arithmetic on integers + Number.objects.filter(pk=self.n.pk).update(integer=F('integer') % 20) + + self.assertEqual(Number.objects.get(pk=self.n.pk).integer, 2) + self.assertEqual(Number.objects.get(pk=self.n.pk).float, Approximate(15.500, places=3)) + + def test_lefthand_bitwise_and(self): + # LH Bitwise ands on integers + Number.objects.filter(pk=self.n.pk).update(integer=F('integer').bitand(56)) + Number.objects.filter(pk=self.n1.pk).update(integer=F('integer').bitand(-56)) + + self.assertEqual(Number.objects.get(pk=self.n.pk).integer, 40) + self.assertEqual(Number.objects.get(pk=self.n1.pk).integer, -64) + self.assertEqual(Number.objects.get(pk=self.n.pk).float, Approximate(15.500, places=3)) + + def test_lefthand_bitwise_left_shift_operator(self): + Number.objects.update(integer=F('integer').bitleftshift(2)) + self.assertEqual(Number.objects.get(pk=self.n.pk).integer, 168) + self.assertEqual(Number.objects.get(pk=self.n1.pk).integer, -168) + + #def test_lefthand_bitwise_right_shift_operator(self): + # Number.objects.update(integer=F('integer').bitrightshift(2)) + # self.assertEqual(Number.objects.get(pk=self.n.pk).integer, 10) + # self.assertEqual(Number.objects.get(pk=self.n1.pk).integer, -11) + + def test_lefthand_bitwise_or(self): + # LH Bitwise or on integers + Number.objects.update(integer=F('integer').bitor(48)) + + self.assertEqual(Number.objects.get(pk=self.n.pk).integer, 58) + self.assertEqual(Number.objects.get(pk=self.n1.pk).integer, -10) + 
self.assertEqual(Number.objects.get(pk=self.n.pk).float, Approximate(15.500, places=3)) + + def test_lefthand_power(self): + # LH Powert arithmetic operation on floats and integers + Number.objects.filter(pk=self.n.pk).update(integer=F('integer') ** 2, float=F('float') ** 1.5) + self.assertEqual(Number.objects.get(pk=self.n.pk).integer, 1764) + self.assertEqual(Number.objects.get(pk=self.n.pk).float, Approximate(61.02, places=2)) + + def test_right_hand_addition(self): + # Right hand operators + Number.objects.filter(pk=self.n.pk).update(integer=15 + F('integer'), float=42.7 + F('float')) + + # RH Addition of floats and integers + self.assertEqual(Number.objects.get(pk=self.n.pk).integer, 57) + self.assertEqual(Number.objects.get(pk=self.n.pk).float, Approximate(58.200, places=3)) + + def test_right_hand_subtraction(self): + Number.objects.filter(pk=self.n.pk).update(integer=15 - F('integer'), float=42.7 - F('float')) + + # RH Subtraction of floats and integers + self.assertEqual(Number.objects.get(pk=self.n.pk).integer, -27) + self.assertEqual(Number.objects.get(pk=self.n.pk).float, Approximate(27.200, places=3)) + + def test_right_hand_multiplication(self): + # RH Multiplication of floats and integers + Number.objects.filter(pk=self.n.pk).update(integer=15 * F('integer'), float=42.7 * F('float')) + + self.assertEqual(Number.objects.get(pk=self.n.pk).integer, 630) + self.assertEqual(Number.objects.get(pk=self.n.pk).float, Approximate(661.850, places=3)) + + def test_right_hand_division(self): + # RH Division of floats and integers + Number.objects.filter(pk=self.n.pk).update(integer=640 / F('integer'), float=42.7 / F('float')) + + self.assertEqual(Number.objects.get(pk=self.n.pk).integer, 15) + self.assertEqual(Number.objects.get(pk=self.n.pk).float, Approximate(2.755, places=3)) + + def test_right_hand_modulo(self): + # RH Modulo arithmetic on integers + Number.objects.filter(pk=self.n.pk).update(integer=69 % F('integer')) + + 
self.assertEqual(Number.objects.get(pk=self.n.pk).integer, 27) + self.assertEqual(Number.objects.get(pk=self.n.pk).float, Approximate(15.500, places=3)) + + #def test_righthand_power(self): + # # RH Powert arithmetic operation on floats and integers + # Number.objects.filter(pk=self.n.pk).update(integer=2 ** F('integer'), float=1.5 ** F('float')) + # self.assertEqual(Number.objects.get(pk=self.n.pk).integer, 4398046511104) + # self.assertEqual(Number.objects.get(pk=self.n.pk).float, Approximate(536.308, places=3)) + + +class FTimeDeltaTests(TestCase): + + @classmethod + def setUpTestData(cls): + cls.sday = sday = datetime.date(2010, 6, 25) + cls.stime = stime = datetime.datetime(2010, 6, 25, 12, 15, 30, 747000) + midnight = datetime.time(0) + + delta0 = datetime.timedelta(0) + delta1 = datetime.timedelta(microseconds=253000) + delta2 = datetime.timedelta(seconds=44) + delta3 = datetime.timedelta(hours=21, minutes=8) + delta4 = datetime.timedelta(days=10) + delta5 = datetime.timedelta(days=90) + + # Test data is set so that deltas and delays will be + # strictly increasing. + cls.deltas = [] + cls.delays = [] + cls.days_long = [] + + # e0: started same day as assigned, zero duration + end = stime + delta0 + e0 = Experiment.objects.create( + name='e0', assigned=sday, start=stime, end=end, + completed=end.date(), estimated_time=delta0, + ) + cls.deltas.append(delta0) + cls.delays.append(e0.start - datetime.datetime.combine(e0.assigned, midnight)) + cls.days_long.append(e0.completed - e0.assigned) + + # e1: started one day after assigned, tiny duration, data + # set so that end time has no fractional seconds, which + # tests an edge case on sqlite. 
+ delay = datetime.timedelta(1) + end = stime + delay + delta1 + e1 = Experiment.objects.create( + name='e1', assigned=sday, start=stime + delay, end=end, + completed=end.date(), estimated_time=delta1, + ) + cls.deltas.append(delta1) + cls.delays.append(e1.start - datetime.datetime.combine(e1.assigned, midnight)) + cls.days_long.append(e1.completed - e1.assigned) + + # e2: started three days after assigned, small duration + end = stime + delta2 + e2 = Experiment.objects.create( + name='e2', assigned=sday - datetime.timedelta(3), start=stime, + end=end, completed=end.date(), estimated_time=datetime.timedelta(hours=1), + ) + cls.deltas.append(delta2) + cls.delays.append(e2.start - datetime.datetime.combine(e2.assigned, midnight)) + cls.days_long.append(e2.completed - e2.assigned) + + # e3: started four days after assigned, medium duration + delay = datetime.timedelta(4) + end = stime + delay + delta3 + e3 = Experiment.objects.create( + name='e3', assigned=sday, start=stime + delay, end=end, + completed=end.date(), estimated_time=delta3, + ) + cls.deltas.append(delta3) + cls.delays.append(e3.start - datetime.datetime.combine(e3.assigned, midnight)) + cls.days_long.append(e3.completed - e3.assigned) + + # e4: started 10 days after assignment, long duration + end = stime + delta4 + e4 = Experiment.objects.create( + name='e4', assigned=sday - datetime.timedelta(10), start=stime, + end=end, completed=end.date(), estimated_time=delta4 - datetime.timedelta(1), + ) + cls.deltas.append(delta4) + cls.delays.append(e4.start - datetime.datetime.combine(e4.assigned, midnight)) + cls.days_long.append(e4.completed - e4.assigned) + + # e5: started a month after assignment, very long duration + delay = datetime.timedelta(30) + end = stime + delay + delta5 + e5 = Experiment.objects.create( + name='e5', assigned=sday, start=stime + delay, end=end, + completed=end.date(), estimated_time=delta5, + ) + cls.deltas.append(delta5) + cls.delays.append(e5.start - 
datetime.datetime.combine(e5.assigned, midnight)) + cls.days_long.append(e5.completed - e5.assigned) + + cls.expnames = [e.name for e in Experiment.objects.all()] + + def test_multiple_query_compilation(self): + # Ticket #21643 + queryset = Experiment.objects.filter(end__lt=F('start') + datetime.timedelta(hours=1)) + q1 = str(queryset.query) + q2 = str(queryset.query) + self.assertEqual(q1, q2) + + def test_query_clone(self): + # Ticket #21643 - Crash when compiling query more than once + qs = Experiment.objects.filter(end__lt=F('start') + datetime.timedelta(hours=1)) + qs2 = qs.all() + list(qs) + list(qs2) + # Intentionally no assert + + def test_delta_add(self): + for i in range(len(self.deltas)): + delta = self.deltas[i] + test_set = [e.name for e in Experiment.objects.filter(end__lt=F('start') + delta)] + self.assertEqual(test_set, self.expnames[:i]) + + test_set = [e.name for e in Experiment.objects.filter(end__lt=delta + F('start'))] + self.assertEqual(test_set, self.expnames[:i]) + + test_set = [e.name for e in Experiment.objects.filter(end__lte=F('start') + delta)] + self.assertEqual(test_set, self.expnames[:i + 1]) + + def test_delta_subtract(self): + for i in range(len(self.deltas)): + delta = self.deltas[i] + test_set = [e.name for e in Experiment.objects.filter(start__gt=F('end') - delta)] + self.assertEqual(test_set, self.expnames[:i]) + + test_set = [e.name for e in Experiment.objects.filter(start__gte=F('end') - delta)] + self.assertEqual(test_set, self.expnames[:i + 1]) + + def test_exclude(self): + for i in range(len(self.deltas)): + delta = self.deltas[i] + test_set = [e.name for e in Experiment.objects.exclude(end__lt=F('start') + delta)] + self.assertEqual(test_set, self.expnames[i:]) + + test_set = [e.name for e in Experiment.objects.exclude(end__lte=F('start') + delta)] + self.assertEqual(test_set, self.expnames[i + 1:]) + + def test_date_comparison(self): + for i in range(len(self.days_long)): + days = self.days_long[i] + test_set = [e.name 
for e in Experiment.objects.filter(completed__lt=F('assigned') + days)] + self.assertEqual(test_set, self.expnames[:i]) + + test_set = [e.name for e in Experiment.objects.filter(completed__lte=F('assigned') + days)] + self.assertEqual(test_set, self.expnames[:i + 1]) + + @skipUnlessDBFeature("supports_mixed_date_datetime_comparisons") + def test_mixed_comparisons1(self): + for i in range(len(self.delays)): + delay = self.delays[i] + test_set = [e.name for e in Experiment.objects.filter(assigned__gt=F('start') - delay)] + self.assertEqual(test_set, self.expnames[:i]) + + test_set = [e.name for e in Experiment.objects.filter(assigned__gte=F('start') - delay)] + self.assertEqual(test_set, self.expnames[:i + 1]) + + def test_mixed_comparisons2(self): + delays = [datetime.timedelta(delay.days) for delay in self.delays] + for i in range(len(delays)): + delay = delays[i] + test_set = [e.name for e in Experiment.objects.filter(start__lt=F('assigned') + delay)] + self.assertEqual(test_set, self.expnames[:i]) + + test_set = [ + e.name for e in Experiment.objects.filter(start__lte=F('assigned') + delay + datetime.timedelta(1)) + ] + self.assertEqual(test_set, self.expnames[:i + 1]) + + def test_delta_update(self): + for i in range(len(self.deltas)): + delta = self.deltas[i] + exps = Experiment.objects.all() + expected_durations = [e.duration() for e in exps] + expected_starts = [e.start + delta for e in exps] + expected_ends = [e.end + delta for e in exps] + + Experiment.objects.update(start=F('start') + delta, end=F('end') + delta) + exps = Experiment.objects.all() + new_starts = [e.start for e in exps] + new_ends = [e.end for e in exps] + new_durations = [e.duration() for e in exps] + self.assertEqual(expected_starts, new_starts) + self.assertEqual(expected_ends, new_ends) + self.assertEqual(expected_durations, new_durations) + + #def test_invalid_operator(self): + # with self.assertRaises(DatabaseError): + # list(Experiment.objects.filter(start=F('start') * 
datetime.timedelta(0))) + + def test_durationfield_add(self): + zeros = [e.name for e in Experiment.objects.filter(start=F('start') + F('estimated_time'))] + self.assertEqual(zeros, ['e0']) + + end_less = [e.name for e in Experiment.objects.filter(end__lt=F('start') + F('estimated_time'))] + self.assertEqual(end_less, ['e2']) + + delta_math = [ + e.name for e in + Experiment.objects.filter(end__gte=F('start') + F('estimated_time') + datetime.timedelta(hours=1)) + ] + self.assertEqual(delta_math, ['e4']) + + @skipUnlessDBFeature('supports_temporal_subtraction') + def test_date_subtraction(self): + queryset = Experiment.objects.annotate( + completion_duration=ExpressionWrapper( + F('completed') - F('assigned'), output_field=models.DurationField() + ) + ) + + at_least_5_days = {e.name for e in queryset.filter(completion_duration__gte=datetime.timedelta(days=5))} + self.assertEqual(at_least_5_days, {'e3', 'e4', 'e5'}) + + at_least_120_days = {e.name for e in queryset.filter(completion_duration__gte=datetime.timedelta(days=120))} + self.assertEqual(at_least_120_days, {'e5'}) + + less_than_5_days = {e.name for e in queryset.filter(completion_duration__lt=datetime.timedelta(days=5))} + self.assertEqual(less_than_5_days, {'e0', 'e1', 'e2'}) + + @skipUnlessDBFeature('supports_temporal_subtraction') + def test_time_subtraction(self): + Time.objects.create(time=datetime.time(12, 30, 15, 2345)) + queryset = Time.objects.annotate( + difference=ExpressionWrapper( + F('time') - Value(datetime.time(11, 15, 0), output_field=models.TimeField()), + output_field=models.DurationField(), + ) + ) + self.assertEqual( + queryset.get().difference, + datetime.timedelta(hours=1, minutes=15, seconds=15, microseconds=2345) + ) + + @skipUnlessDBFeature('supports_temporal_subtraction') + def test_datetime_subtraction(self): + under_estimate = [ + e.name for e in Experiment.objects.filter(estimated_time__gt=F('end') - F('start')) + ] + self.assertEqual(under_estimate, ['e2']) + + over_estimate = [ 
+ e.name for e in Experiment.objects.filter(estimated_time__lt=F('end') - F('start')) + ] + self.assertEqual(over_estimate, ['e4']) + + #@skipUnlessDBFeature('supports_temporal_subtraction') + #def test_datetime_subtraction_microseconds(self): + # delta = datetime.timedelta(microseconds=8999999999999999) + # Experiment.objects.update(end=F('start') + delta) + # qs = Experiment.objects.annotate( + # delta=ExpressionWrapper(F('end') - F('start'), output_field=models.DurationField()) + # ) + # for e in qs: + # self.assertEqual(e.delta, delta) + + def test_duration_with_datetime(self): + # Exclude e1 which has very high precision so we can test this on all + # backends regardless of whether or not it supports + # microsecond_precision. + over_estimate = Experiment.objects.exclude(name='e1').filter( + completed__gt=self.stime + F('estimated_time'), + ).order_by('name') + self.assertQuerysetEqual(over_estimate, ['e3', 'e4', 'e5'], lambda e: e.name) + + #def test_duration_with_datetime_microseconds(self): + # delta = datetime.timedelta(microseconds=8999999999999999) + # qs = Experiment.objects.annotate(dt=ExpressionWrapper( + # F('start') + delta, + # output_field=models.DateTimeField(), + # )) + # for e in qs: + # self.assertEqual(e.dt, e.start + delta) + + def test_date_minus_duration(self): + more_than_4_days = Experiment.objects.filter( + assigned__lt=F('completed') - Value(datetime.timedelta(days=4), output_field=models.DurationField()) + ) + self.assertQuerysetEqual(more_than_4_days, ['e3', 'e4', 'e5'], lambda e: e.name) + + def test_negative_timedelta_update(self): + # subtract 30 seconds, 30 minutes, 2 hours and 2 days + experiments = Experiment.objects.filter(name='e0').annotate( + start_sub_seconds=F('start') + datetime.timedelta(seconds=-30), + ).annotate( + start_sub_minutes=F('start_sub_seconds') + datetime.timedelta(minutes=-30), + ).annotate( + start_sub_hours=F('start_sub_minutes') + datetime.timedelta(hours=-2), + ).annotate( + 
new_start=F('start_sub_hours') + datetime.timedelta(days=-2), + ) + expected_start = datetime.datetime(2010, 6, 23, 9, 45, 0) + # subtract 30 microseconds + experiments = experiments.annotate(new_start=F('new_start') + datetime.timedelta(microseconds=-30)) + expected_start += datetime.timedelta(microseconds=+746970) + experiments.update(start=F('new_start')) + e0 = Experiment.objects.get(name='e0') + self.assertEqual(e0.start, expected_start) + + +class ValueTests(TestCase): + def test_update_TimeField_using_Value(self): + Time.objects.create() + Time.objects.update(time=Value(datetime.time(1), output_field=TimeField())) + self.assertEqual(Time.objects.get().time, datetime.time(1)) + + def test_update_UUIDField_using_Value(self): + UUID.objects.create() + UUID.objects.update(uuid=Value(uuid.UUID('12345678901234567890123456789012'), output_field=UUIDField())) + self.assertEqual(UUID.objects.get().uuid, uuid.UUID('12345678901234567890123456789012')) + + def test_deconstruct(self): + value = Value('name') + path, args, kwargs = value.deconstruct() + self.assertEqual(path, 'django.db.models.expressions.Value') + self.assertEqual(args, (value.value,)) + self.assertEqual(kwargs, {}) + + def test_deconstruct_output_field(self): + value = Value('name', output_field=CharField()) + path, args, kwargs = value.deconstruct() + self.assertEqual(path, 'django.db.models.expressions.Value') + self.assertEqual(args, (value.value,)) + self.assertEqual(len(kwargs), 1) + self.assertEqual(kwargs['output_field'].deconstruct(), CharField().deconstruct()) + + def test_equal(self): + value = Value('name') + same_value = Value('name') + other_value = Value('username') + self.assertEqual(value, same_value) + self.assertNotEqual(value, other_value) + + def test_hash(self): + d = {Value('name'): 'Bob'} + self.assertIn(Value('name'), d) + self.assertEqual(d[Value('name')], 'Bob') + + def test_equal_output_field(self): + value = Value('name', output_field=CharField()) + same_value = Value('name', 
output_field=CharField()) + other_value = Value('name', output_field=TimeField()) + no_output_field = Value('name') + self.assertEqual(value, same_value) + self.assertNotEqual(value, other_value) + self.assertNotEqual(value, no_output_field) + + def test_raise_empty_expressionlist(self): + msg = 'ExpressionList requires at least one expression' + with self.assertRaisesMessage(ValueError, msg): + ExpressionList() + + +class FieldTransformTests(TestCase): + + @classmethod + def setUpTestData(cls): + cls.sday = sday = datetime.date(2010, 6, 25) + cls.stime = stime = datetime.datetime(2010, 6, 25, 12, 15, 30, 747000) + cls.ex1 = Experiment.objects.create( + name='Experiment 1', + assigned=sday, + completed=sday + datetime.timedelta(2), + estimated_time=datetime.timedelta(2), + start=stime, + end=stime + datetime.timedelta(2), + ) + + def test_month_aggregation(self): + self.assertEqual( + Experiment.objects.aggregate(month_count=Count('assigned__month')), + {'month_count': 1} + ) + + def test_transform_in_values(self): + self.assertQuerysetEqual( + Experiment.objects.values('assigned__month'), + ["{'assigned__month': 6}"] + ) + + def test_multiple_transforms_in_values(self): + self.assertQuerysetEqual( + Experiment.objects.values('end__date__month'), + ["{'end__date__month': 6}"] + ) + + +class ReprTests(TestCase): + + def test_expressions(self): + self.assertEqual( + repr(Case(When(a=1))), + " THEN Value(None), ELSE Value(None)>" + ) + self.assertEqual( + repr(When(Q(age__gte=18), then=Value('legal'))), + " THEN Value(legal)>" + ) + self.assertEqual(repr(Col('alias', 'field')), "Col(alias, field)") + self.assertEqual(repr(F('published')), "F(published)") + self.assertEqual(repr(F('cost') + F('tax')), "") + self.assertEqual( + repr(ExpressionWrapper(F('cost') + F('tax'), models.IntegerField())), + "ExpressionWrapper(F(cost) + F(tax))" + ) + self.assertEqual(repr(Func('published', function='TO_CHAR')), "Func(F(published), function=TO_CHAR)") + 
self.assertEqual(repr(OrderBy(Value(1))), 'OrderBy(Value(1), descending=False)') + self.assertEqual(repr(Random()), "Random()") + self.assertEqual(repr(RawSQL('table.col', [])), "RawSQL(table.col, [])") + self.assertEqual(repr(Ref('sum_cost', Sum('cost'))), "Ref(sum_cost, Sum(F(cost)))") + self.assertEqual(repr(Value(1)), "Value(1)") + self.assertEqual( + repr(ExpressionList(F('col'), F('anothercol'))), + 'ExpressionList(F(col), F(anothercol))' + ) + self.assertEqual( + repr(ExpressionList(OrderBy(F('col'), descending=False))), + 'ExpressionList(OrderBy(F(col), descending=False))' + ) + + def test_functions(self): + self.assertEqual(repr(Coalesce('a', 'b')), "Coalesce(F(a), F(b))") + self.assertEqual(repr(Concat('a', 'b')), "Concat(ConcatPair(F(a), F(b)))") + self.assertEqual(repr(Length('a')), "Length(F(a))") + self.assertEqual(repr(Lower('a')), "Lower(F(a))") + self.assertEqual(repr(Substr('a', 1, 3)), "Substr(F(a), Value(1), Value(3))") + self.assertEqual(repr(Upper('a')), "Upper(F(a))") + + def test_aggregates(self): + self.assertEqual(repr(Avg('a')), "Avg(F(a))") + self.assertEqual(repr(Count('a')), "Count(F(a), distinct=False)") + self.assertEqual(repr(Count('*')), "Count('*', distinct=False)") + self.assertEqual(repr(Max('a')), "Max(F(a))") + self.assertEqual(repr(Min('a')), "Min(F(a))") + self.assertEqual(repr(StdDev('a')), "StdDev(F(a), sample=False)") + self.assertEqual(repr(Sum('a')), "Sum(F(a))") + self.assertEqual(repr(Variance('a', sample=True)), "Variance(F(a), sample=True)") + + def test_filtered_aggregates(self): + filter = Q(a=1) + self.assertEqual(repr(Avg('a', filter=filter)), "Avg(F(a), filter=(AND: ('a', 1)))") + self.assertEqual(repr(Count('a', filter=filter)), "Count(F(a), distinct=False, filter=(AND: ('a', 1)))") + self.assertEqual(repr(Max('a', filter=filter)), "Max(F(a), filter=(AND: ('a', 1)))") + self.assertEqual(repr(Min('a', filter=filter)), "Min(F(a), filter=(AND: ('a', 1)))") + self.assertEqual(repr(StdDev('a', filter=filter)), 
"StdDev(F(a), filter=(AND: ('a', 1)), sample=False)") + self.assertEqual(repr(Sum('a', filter=filter)), "Sum(F(a), filter=(AND: ('a', 1)))") + self.assertEqual( + repr(Variance('a', sample=True, filter=filter)), + "Variance(F(a), filter=(AND: ('a', 1)), sample=True)" + ) + + +class CombinableTests(SimpleTestCase): + bitwise_msg = 'Use .bitand() and .bitor() for bitwise logical operations.' + + def test_negation(self): + c = Combinable() + self.assertEqual(-c, c * -1) + + def test_and(self): + with self.assertRaisesMessage(NotImplementedError, self.bitwise_msg): + Combinable() & Combinable() + + def test_or(self): + with self.assertRaisesMessage(NotImplementedError, self.bitwise_msg): + Combinable() | Combinable() + + def test_reversed_and(self): + with self.assertRaisesMessage(NotImplementedError, self.bitwise_msg): + object() & Combinable() + + def test_reversed_or(self): + with self.assertRaisesMessage(NotImplementedError, self.bitwise_msg): + object() | Combinable() diff --git a/tests/field_deconstruction/__init__.py b/tests/field_deconstruction/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/field_deconstruction/tests.py b/tests/field_deconstruction/tests.py new file mode 100644 index 00000000..1a6385db --- /dev/null +++ b/tests/field_deconstruction/tests.py @@ -0,0 +1,396 @@ +from __future__ import unicode_literals + +from django.apps import apps +from django.db import models +from django.test import SimpleTestCase, override_settings +from django.test.utils import isolate_lru_cache +from django.utils import six + + +class FieldDeconstructionTests(SimpleTestCase): + """ + Tests the deconstruct() method on all core fields. + """ + + def test_name(self): + """ + Tests the outputting of the correct name if assigned one. 
+ """ + # First try using a "normal" field + field = models.CharField(max_length=65) + name, path, args, kwargs = field.deconstruct() + self.assertIsNone(name) + field.set_attributes_from_name("is_awesome_test") + name, path, args, kwargs = field.deconstruct() + self.assertEqual(name, "is_awesome_test") + self.assertIsInstance(name, six.text_type) + # Now try with a ForeignKey + field = models.ForeignKey("some_fake.ModelName", models.CASCADE) + name, path, args, kwargs = field.deconstruct() + self.assertIsNone(name) + field.set_attributes_from_name("author") + name, path, args, kwargs = field.deconstruct() + self.assertEqual(name, "author") + + def test_auto_field(self): + field = models.AutoField(primary_key=True) + field.set_attributes_from_name("id") + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.AutoField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"primary_key": True}) + + def test_big_integer_field(self): + field = models.BigIntegerField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.BigIntegerField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {}) + + def test_boolean_field(self): + field = models.BooleanField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.BooleanField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {}) + field = models.BooleanField(default=True) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.BooleanField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"default": True}) + + def test_char_field(self): + field = models.CharField(max_length=65) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.CharField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"max_length": 65}) + field = models.CharField(max_length=65, null=True, blank=True) + name, path, args, kwargs = 
field.deconstruct() + self.assertEqual(path, "django.db.models.CharField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"max_length": 65, "null": True, "blank": True}) + + def test_char_field_choices(self): + field = models.CharField(max_length=1, choices=(("A", "One"), ("B", "Two"))) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.CharField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"choices": [("A", "One"), ("B", "Two")], "max_length": 1}) + + def test_csi_field(self): + field = models.CommaSeparatedIntegerField(max_length=100) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.CommaSeparatedIntegerField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"max_length": 100}) + + def test_date_field(self): + field = models.DateField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.DateField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {}) + field = models.DateField(auto_now=True) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.DateField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"auto_now": True}) + + def test_datetime_field(self): + field = models.DateTimeField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.DateTimeField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {}) + field = models.DateTimeField(auto_now_add=True) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.DateTimeField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"auto_now_add": True}) + # Bug #21785 + field = models.DateTimeField(auto_now=True, auto_now_add=True) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.DateTimeField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"auto_now_add": True, 
"auto_now": True}) + + def test_decimal_field(self): + field = models.DecimalField(max_digits=5, decimal_places=2) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.DecimalField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"max_digits": 5, "decimal_places": 2}) + + def test_decimal_field_0_decimal_places(self): + """ + A DecimalField with decimal_places=0 should work (#22272). + """ + field = models.DecimalField(max_digits=5, decimal_places=0) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.DecimalField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"max_digits": 5, "decimal_places": 0}) + + def test_email_field(self): + field = models.EmailField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.EmailField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"max_length": 254}) + field = models.EmailField(max_length=255) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.EmailField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"max_length": 255}) + + def test_file_field(self): + field = models.FileField(upload_to="foo/bar") + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.FileField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"upload_to": "foo/bar"}) + # Test max_length + field = models.FileField(upload_to="foo/bar", max_length=200) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.FileField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"upload_to": "foo/bar", "max_length": 200}) + + def test_file_path_field(self): + field = models.FilePathField(match=r".*\.txt$") + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.FilePathField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"match": 
r".*\.txt$"}) + field = models.FilePathField(recursive=True, allow_folders=True, max_length=123) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.FilePathField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"recursive": True, "allow_folders": True, "max_length": 123}) + + def test_float_field(self): + field = models.FloatField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.FloatField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {}) + + def test_foreign_key(self): + # Test basic pointing + from django.contrib.auth.models import Permission + field = models.ForeignKey("auth.Permission", models.CASCADE) + field.remote_field.model = Permission + field.remote_field.field_name = "id" + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.ForeignKey") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"to": "auth.Permission", "on_delete": models.CASCADE}) + self.assertFalse(hasattr(kwargs['to'], "setting_name")) + # Test swap detection for swappable model + field = models.ForeignKey("auth.User", models.CASCADE) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.ForeignKey") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"to": "auth.User", "on_delete": models.CASCADE}) + self.assertEqual(kwargs['to'].setting_name, "AUTH_USER_MODEL") + # Test nonexistent (for now) model + field = models.ForeignKey("something.Else", models.CASCADE) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.ForeignKey") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"to": "something.Else", "on_delete": models.CASCADE}) + # Test on_delete + field = models.ForeignKey("auth.User", models.SET_NULL) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.ForeignKey") + self.assertEqual(args, []) + 
self.assertEqual(kwargs, {"to": "auth.User", "on_delete": models.SET_NULL}) + # Test to_field preservation + field = models.ForeignKey("auth.Permission", models.CASCADE, to_field="foobar") + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.ForeignKey") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"to": "auth.Permission", "to_field": "foobar", "on_delete": models.CASCADE}) + # Test related_name preservation + field = models.ForeignKey("auth.Permission", models.CASCADE, related_name="foobar") + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.ForeignKey") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"to": "auth.Permission", "related_name": "foobar", "on_delete": models.CASCADE}) + + @override_settings(AUTH_USER_MODEL="auth.Permission") + def test_foreign_key_swapped(self): + with isolate_lru_cache(apps.get_swappable_settings_name): + # It doesn't matter that we swapped out user for permission; + # there's no validation. We just want to check the setting stuff works. 
+ field = models.ForeignKey("auth.Permission", models.CASCADE) + name, path, args, kwargs = field.deconstruct() + + self.assertEqual(path, "django.db.models.ForeignKey") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"to": "auth.Permission", "on_delete": models.CASCADE}) + self.assertEqual(kwargs['to'].setting_name, "AUTH_USER_MODEL") + + def test_image_field(self): + field = models.ImageField(upload_to="foo/barness", width_field="width", height_field="height") + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.ImageField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"upload_to": "foo/barness", "width_field": "width", "height_field": "height"}) + + def test_integer_field(self): + field = models.IntegerField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.IntegerField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {}) + + def test_ip_address_field(self): + field = models.IPAddressField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.IPAddressField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {}) + + def test_generic_ip_address_field(self): + field = models.GenericIPAddressField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.GenericIPAddressField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {}) + field = models.GenericIPAddressField(protocol="IPv6") + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.GenericIPAddressField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"protocol": "IPv6"}) + + def test_many_to_many_field(self): + # Test normal + field = models.ManyToManyField("auth.Permission") + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.ManyToManyField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"to": 
"auth.Permission"}) + self.assertFalse(hasattr(kwargs['to'], "setting_name")) + # Test swappable + field = models.ManyToManyField("auth.User") + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.ManyToManyField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"to": "auth.User"}) + self.assertEqual(kwargs['to'].setting_name, "AUTH_USER_MODEL") + # Test through + field = models.ManyToManyField("auth.Permission", through="auth.Group") + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.ManyToManyField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"to": "auth.Permission", "through": "auth.Group"}) + # Test custom db_table + field = models.ManyToManyField("auth.Permission", db_table="custom_table") + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.ManyToManyField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"to": "auth.Permission", "db_table": "custom_table"}) + # Test related_name + field = models.ManyToManyField("auth.Permission", related_name="custom_table") + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.ManyToManyField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"to": "auth.Permission", "related_name": "custom_table"}) + + @override_settings(AUTH_USER_MODEL="auth.Permission") + def test_many_to_many_field_swapped(self): + with isolate_lru_cache(apps.get_swappable_settings_name): + # It doesn't matter that we swapped out user for permission; + # there's no validation. We just want to check the setting stuff works. 
+ field = models.ManyToManyField("auth.Permission") + name, path, args, kwargs = field.deconstruct() + + self.assertEqual(path, "django.db.models.ManyToManyField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"to": "auth.Permission"}) + self.assertEqual(kwargs['to'].setting_name, "AUTH_USER_MODEL") + + def test_null_boolean_field(self): + field = models.NullBooleanField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.NullBooleanField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {}) + + def test_positive_integer_field(self): + field = models.PositiveIntegerField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.PositiveIntegerField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {}) + + def test_positive_small_integer_field(self): + field = models.PositiveSmallIntegerField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.PositiveSmallIntegerField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {}) + + def test_slug_field(self): + field = models.SlugField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.SlugField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {}) + field = models.SlugField(db_index=False, max_length=231) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.SlugField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"db_index": False, "max_length": 231}) + + def test_small_integer_field(self): + field = models.SmallIntegerField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.SmallIntegerField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {}) + + def test_text_field(self): + field = models.TextField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.TextField") + 
self.assertEqual(args, []) + self.assertEqual(kwargs, {}) + + def test_time_field(self): + field = models.TimeField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.TimeField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {}) + + field = models.TimeField(auto_now=True) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(args, []) + self.assertEqual(kwargs, {'auto_now': True}) + + field = models.TimeField(auto_now_add=True) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(args, []) + self.assertEqual(kwargs, {'auto_now_add': True}) + + def test_url_field(self): + field = models.URLField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.URLField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {}) + field = models.URLField(max_length=231) + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.URLField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {"max_length": 231}) + + def test_binary_field(self): + field = models.BinaryField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django.db.models.BinaryField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {}) diff --git a/tests/field_defaults/__init__.py b/tests/field_defaults/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/field_defaults/models.py b/tests/field_defaults/models.py new file mode 100644 index 00000000..4f062320 --- /dev/null +++ b/tests/field_defaults/models.py @@ -0,0 +1,25 @@ +# coding: utf-8 +""" +Callable defaults + +You can pass callable objects as the ``default`` parameter to a field. When +the object is created without an explicit value passed in, Django will call +the method to determine the default value. + +This example uses ``datetime.datetime.now`` as the default for the ``pub_date`` +field. 
+""" + +from datetime import datetime + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Article(models.Model): + headline = models.CharField(max_length=100, default='Default headline') + pub_date = models.DateTimeField(default=datetime.now) + + def __str__(self): + return self.headline diff --git a/tests/field_defaults/tests.py b/tests/field_defaults/tests.py new file mode 100644 index 00000000..031fd75f --- /dev/null +++ b/tests/field_defaults/tests.py @@ -0,0 +1,17 @@ +from datetime import datetime + +from django.test import TestCase +from django.utils import six + +from .models import Article + + +class DefaultTests(TestCase): + def test_field_defaults(self): + a = Article() + now = datetime.now() + a.save() + + self.assertIsInstance(a.id, six.integer_types) + self.assertEqual(a.headline, "Default headline") + self.assertLess((now - a.pub_date).seconds, 5) diff --git a/tests/field_subclassing/__init__.py b/tests/field_subclassing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/field_subclassing/fields.py b/tests/field_subclassing/fields.py new file mode 100644 index 00000000..c2e4b50c --- /dev/null +++ b/tests/field_subclassing/fields.py @@ -0,0 +1,8 @@ +from __future__ import unicode_literals + +from django.db import models + + +class CustomTypedField(models.TextField): + def db_type(self, connection): + return 'custom_field' diff --git a/tests/field_subclassing/tests.py b/tests/field_subclassing/tests.py new file mode 100644 index 00000000..d291276c --- /dev/null +++ b/tests/field_subclassing/tests.py @@ -0,0 +1,13 @@ +from __future__ import unicode_literals + +from django.db import connection +from django.test import SimpleTestCase + +from .fields import CustomTypedField + + +class TestDbType(SimpleTestCase): + + def test_db_parameters_respects_db_type(self): + f = CustomTypedField() + self.assertEqual(f.db_parameters(connection)['type'], 
'custom_field') diff --git a/tests/force_insert_update/__init__.py b/tests/force_insert_update/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/force_insert_update/models.py b/tests/force_insert_update/models.py new file mode 100644 index 00000000..a98eadb4 --- /dev/null +++ b/tests/force_insert_update/models.py @@ -0,0 +1,28 @@ +""" +Tests for forcing insert and update queries (instead of Django's normal +automatic behavior). +""" +from django.db import models + + +class Counter(models.Model): + name = models.CharField(max_length=10) + value = models.IntegerField() + + +class InheritedCounter(Counter): + tag = models.CharField(max_length=10) + + +class ProxyCounter(Counter): + class Meta: + proxy = True + + +class SubCounter(Counter): + pass + + +class WithCustomPK(models.Model): + name = models.IntegerField(primary_key=True) + value = models.IntegerField() diff --git a/tests/force_insert_update/tests.py b/tests/force_insert_update/tests.py new file mode 100644 index 00000000..ae8b771e --- /dev/null +++ b/tests/force_insert_update/tests.py @@ -0,0 +1,69 @@ +from __future__ import unicode_literals + +from django.db import DatabaseError, IntegrityError, transaction +from django.test import TestCase + +from .models import ( + Counter, InheritedCounter, ProxyCounter, SubCounter, WithCustomPK, +) + + +class ForceTests(TestCase): + def test_force_update(self): + c = Counter.objects.create(name="one", value=1) + + # The normal case + c.value = 2 + c.save() + # Same thing, via an update + c.value = 3 + c.save(force_update=True) + + # Won't work because force_update and force_insert are mutually + # exclusive + c.value = 4 + with self.assertRaises(ValueError): + c.save(force_insert=True, force_update=True) + + # Try to update something that doesn't have a primary key in the first + # place. 
+ c1 = Counter(name="two", value=2) + with self.assertRaises(ValueError): + with transaction.atomic(): + c1.save(force_update=True) + c1.save(force_insert=True) + + # Won't work because we can't insert a pk of the same value. + c.value = 5 + with self.assertRaises(IntegrityError): + with transaction.atomic(): + c.save(force_insert=True) + + # Trying to update should still fail, even with manual primary keys, if + # the data isn't in the database already. + obj = WithCustomPK(name=1, value=1) + with self.assertRaises(DatabaseError): + with transaction.atomic(): + obj.save(force_update=True) + + +class InheritanceTests(TestCase): + def test_force_update_on_inherited_model(self): + a = InheritedCounter(name="count", value=1, tag="spam") + a.save() + a.save(force_update=True) + + def test_force_update_on_proxy_model(self): + a = ProxyCounter(name="count", value=1) + a.save() + a.save(force_update=True) + + def test_force_update_on_inherited_model_without_fields(self): + ''' + Issue 13864: force_update fails on subclassed models, if they don't + specify custom fields. 
+ ''' + a = SubCounter(name="count", value=1) + a.save() + a.value = 2 + a.save(force_update=True) diff --git a/tests/generic_relations_regress/__init__.py b/tests/generic_relations_regress/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/generic_relations_regress/models.py b/tests/generic_relations_regress/models.py new file mode 100644 index 00000000..eb4f645d --- /dev/null +++ b/tests/generic_relations_regress/models.py @@ -0,0 +1,218 @@ +from django.contrib.contenttypes.fields import ( + GenericForeignKey, GenericRelation, +) +from django.contrib.contenttypes.models import ContentType +from django.db import models +from django.db.models.deletion import ProtectedError +from django.utils.encoding import python_2_unicode_compatible + +__all__ = ('Link', 'Place', 'Restaurant', 'Person', 'Address', + 'CharLink', 'TextLink', 'OddRelation1', 'OddRelation2', + 'Contact', 'Organization', 'Note', 'Company') + + +@python_2_unicode_compatible +class Link(models.Model): + content_type = models.ForeignKey(ContentType, models.CASCADE) + object_id = models.PositiveIntegerField() + content_object = GenericForeignKey() + + def __str__(self): + return "Link to %s id=%s" % (self.content_type, self.object_id) + + +@python_2_unicode_compatible +class Place(models.Model): + name = models.CharField(max_length=100) + links = GenericRelation(Link) + + def __str__(self): + return "Place: %s" % self.name + + +@python_2_unicode_compatible +class Restaurant(Place): + def __str__(self): + return "Restaurant: %s" % self.name + + +@python_2_unicode_compatible +class Address(models.Model): + street = models.CharField(max_length=80) + city = models.CharField(max_length=50) + state = models.CharField(max_length=2) + zipcode = models.CharField(max_length=5) + content_type = models.ForeignKey(ContentType, models.CASCADE) + object_id = models.PositiveIntegerField() + content_object = GenericForeignKey() + + def __str__(self): + return '%s %s, %s %s' % (self.street, 
self.city, self.state, self.zipcode) + + +@python_2_unicode_compatible +class Person(models.Model): + account = models.IntegerField(primary_key=True) + name = models.CharField(max_length=128) + addresses = GenericRelation(Address) + + def __str__(self): + return self.name + + +class CharLink(models.Model): + content_type = models.ForeignKey(ContentType, models.CASCADE) + object_id = models.CharField(max_length=100) + content_object = GenericForeignKey() + + +class TextLink(models.Model): + content_type = models.ForeignKey(ContentType, models.CASCADE) + object_id = models.TextField() + content_object = GenericForeignKey() + + +class OddRelation1(models.Model): + name = models.CharField(max_length=100) + clinks = GenericRelation(CharLink) + + +class OddRelation2(models.Model): + name = models.CharField(max_length=100) + tlinks = GenericRelation(TextLink) + + +# models for test_q_object_or: +class Note(models.Model): + content_type = models.ForeignKey(ContentType, models.CASCADE) + object_id = models.PositiveIntegerField() + content_object = GenericForeignKey() + note = models.TextField() + + +class Contact(models.Model): + notes = GenericRelation(Note) + + +class Organization(models.Model): + name = models.CharField(max_length=255) + contacts = models.ManyToManyField(Contact, related_name='organizations') + + +@python_2_unicode_compatible +class Company(models.Model): + name = models.CharField(max_length=100) + links = GenericRelation(Link) + + def __str__(self): + return "Company: %s" % self.name + + +# For testing #13085 fix, we also use Note model defined above +class Developer(models.Model): + name = models.CharField(max_length=15) + + +@python_2_unicode_compatible +class Team(models.Model): + name = models.CharField(max_length=15) + members = models.ManyToManyField(Developer) + + def __str__(self): + return "%s team" % self.name + + def __len__(self): + return self.members.count() + + +class Guild(models.Model): + name = models.CharField(max_length=15) + members 
= models.ManyToManyField(Developer) + + def __nonzero__(self): + + return self.members.count() + + +class Tag(models.Model): + content_type = models.ForeignKey(ContentType, models.CASCADE, related_name='g_r_r_tags') + object_id = models.CharField(max_length=15) + content_object = GenericForeignKey() + label = models.CharField(max_length=15) + + +class Board(models.Model): + name = models.CharField(primary_key=True, max_length=15) + + +class SpecialGenericRelation(GenericRelation): + def __init__(self, *args, **kwargs): + super(SpecialGenericRelation, self).__init__(*args, **kwargs) + self.editable = True + self.save_form_data_calls = 0 + + def save_form_data(self, *args, **kwargs): + self.save_form_data_calls += 1 + + +class HasLinks(models.Model): + links = SpecialGenericRelation(Link) + + class Meta: + abstract = True + + +class HasLinkThing(HasLinks): + pass + + +class A(models.Model): + flag = models.NullBooleanField() + content_type = models.ForeignKey(ContentType, models.CASCADE) + object_id = models.PositiveIntegerField() + content_object = GenericForeignKey('content_type', 'object_id') + + +class B(models.Model): + a = GenericRelation(A) + + class Meta: + ordering = ('id',) + + +class C(models.Model): + b = models.ForeignKey(B, models.CASCADE) + + class Meta: + ordering = ('id',) + + +class D(models.Model): + b = models.ForeignKey(B, models.SET_NULL, null=True) + + class Meta: + ordering = ('id',) + + +# Ticket #22998 + +class Node(models.Model): + content_type = models.ForeignKey(ContentType, models.CASCADE) + object_id = models.PositiveIntegerField() + content = GenericForeignKey('content_type', 'object_id') + + +class Content(models.Model): + nodes = GenericRelation(Node) + related_obj = models.ForeignKey('Related', models.CASCADE) + + +class Related(models.Model): + pass + + +def prevent_deletes(sender, instance, **kwargs): + raise ProtectedError("Not allowed to delete.", [instance]) + + +models.signals.pre_delete.connect(prevent_deletes, sender=Node) 
diff --git a/tests/generic_relations_regress/tests.py b/tests/generic_relations_regress/tests.py new file mode 100644 index 00000000..d3986b69 --- /dev/null +++ b/tests/generic_relations_regress/tests.py @@ -0,0 +1,246 @@ +from django.db.models import Q, Sum +from django.db.models.deletion import ProtectedError +from django.db.utils import IntegrityError +from django.forms.models import modelform_factory +from django.test import TestCase, skipIfDBFeature + +from .models import ( + A, Address, B, Board, C, CharLink, Company, Contact, Content, D, Developer, + Guild, HasLinkThing, Link, Node, Note, OddRelation1, OddRelation2, + Organization, Person, Place, Related, Restaurant, Tag, Team, TextLink, +) + + +class GenericRelationTests(TestCase): + + def test_inherited_models_content_type(self): + """ + GenericRelations on inherited classes use the correct content type. + """ + p = Place.objects.create(name="South Park") + r = Restaurant.objects.create(name="Chubby's") + l1 = Link.objects.create(content_object=p) + l2 = Link.objects.create(content_object=r) + self.assertEqual(list(p.links.all()), [l1]) + self.assertEqual(list(r.links.all()), [l2]) + + def test_reverse_relation_pk(self): + """ + The correct column name is used for the primary key on the + originating model of a query. See #12664. 
+ """ + p = Person.objects.create(account=23, name='Chef') + Address.objects.create(street='123 Anywhere Place', + city='Conifer', state='CO', + zipcode='80433', content_object=p) + + qs = Person.objects.filter(addresses__zipcode='80433') + self.assertEqual(1, qs.count()) + self.assertEqual('Chef', qs[0].name) + + def test_charlink_delete(self): + oddrel = OddRelation1.objects.create(name='clink') + CharLink.objects.create(content_object=oddrel) + oddrel.delete() + + def test_textlink_delete(self): + oddrel = OddRelation2.objects.create(name='tlink') + TextLink.objects.create(content_object=oddrel) + oddrel.delete() + + def test_q_object_or(self): + """ + SQL query parameters for generic relations are properly + grouped when OR is used (#11535). + + In this bug the first query (below) works while the second, with the + query parameters the same but in reverse order, does not. + + The issue is that the generic relation conditions do not get properly + grouped in parentheses. + """ + note_contact = Contact.objects.create() + org_contact = Contact.objects.create() + Note.objects.create(note='note', content_object=note_contact) + org = Organization.objects.create(name='org name') + org.contacts.add(org_contact) + # search with a non-matching note and a matching org name + qs = Contact.objects.filter(Q(notes__note__icontains=r'other note') | + Q(organizations__name__icontains=r'org name')) + self.assertIn(org_contact, qs) + # search again, with the same query parameters, in reverse order + qs = Contact.objects.filter( + Q(organizations__name__icontains=r'org name') | + Q(notes__note__icontains=r'other note')) + self.assertIn(org_contact, qs) + + def test_join_reuse(self): + qs = Person.objects.filter( + addresses__street='foo' + ).filter( + addresses__street='bar' + ) + self.assertEqual(str(qs.query).count('JOIN'), 2) + + def test_generic_relation_ordering(self): + """ + Ordering over a generic relation does not include extraneous + duplicate results, nor excludes rows 
not participating in the relation. + """ + p1 = Place.objects.create(name="South Park") + p2 = Place.objects.create(name="The City") + c = Company.objects.create(name="Chubby's Intl.") + Link.objects.create(content_object=p1) + Link.objects.create(content_object=c) + + places = list(Place.objects.order_by('links__id')) + + def count_places(place): + return len([p for p in places if p.id == place.id]) + + self.assertEqual(len(places), 2) + self.assertEqual(count_places(p1), 1) + self.assertEqual(count_places(p2), 1) + + def test_target_model_is_unsaved(self): + """Test related to #13085""" + # Fails with another, ORM-level error + dev1 = Developer(name='Joe') + note = Note(note='Deserves promotion', content_object=dev1) + with self.assertRaises(IntegrityError): + note.save() + + def test_target_model_len_zero(self): + """ + Saving a model with a GenericForeignKey to a model instance whose + __len__ method returns 0 (Team.__len__() here) shouldn't fail (#13085). + """ + team1 = Team.objects.create(name='Backend devs') + note = Note(note='Deserve a bonus', content_object=team1) + note.save() + + def test_target_model_nonzero_false(self): + """Test related to #13085""" + # __nonzero__() returns False -- This actually doesn't currently fail. + # This test validates that + g1 = Guild.objects.create(name='First guild') + note = Note(note='Note for guild', content_object=g1) + note.save() + + @skipIfDBFeature('interprets_empty_strings_as_nulls') + def test_gfk_to_model_with_empty_pk(self): + """Test related to #13085""" + # Saving model with GenericForeignKey to model instance with an + # empty CharField PK + b1 = Board.objects.create(name='') + tag = Tag(label='VP', content_object=b1) + tag.save() + + def test_ticket_20378(self): + # Create a couple of extra HasLinkThing so that the autopk value + # isn't the same for Link and HasLinkThing. 
+ hs1 = HasLinkThing.objects.create() + hs2 = HasLinkThing.objects.create() + hs3 = HasLinkThing.objects.create() + hs4 = HasLinkThing.objects.create() + l1 = Link.objects.create(content_object=hs3) + l2 = Link.objects.create(content_object=hs4) + self.assertSequenceEqual(HasLinkThing.objects.filter(links=l1), [hs3]) + self.assertSequenceEqual(HasLinkThing.objects.filter(links=l2), [hs4]) + self.assertSequenceEqual(HasLinkThing.objects.exclude(links=l2), [hs1, hs2, hs3]) + self.assertSequenceEqual(HasLinkThing.objects.exclude(links=l1), [hs1, hs2, hs4]) + + def test_ticket_20564(self): + b1 = B.objects.create() + b2 = B.objects.create() + b3 = B.objects.create() + c1 = C.objects.create(b=b1) + c2 = C.objects.create(b=b2) + c3 = C.objects.create(b=b3) + A.objects.create(flag=None, content_object=b1) + A.objects.create(flag=True, content_object=b2) + self.assertSequenceEqual(C.objects.filter(b__a__flag=None), [c1, c3]) + self.assertSequenceEqual(C.objects.exclude(b__a__flag=None), [c2]) + + def test_ticket_20564_nullable_fk(self): + b1 = B.objects.create() + b2 = B.objects.create() + b3 = B.objects.create() + d1 = D.objects.create(b=b1) + d2 = D.objects.create(b=b2) + d3 = D.objects.create(b=b3) + d4 = D.objects.create() + A.objects.create(flag=None, content_object=b1) + A.objects.create(flag=True, content_object=b1) + A.objects.create(flag=True, content_object=b2) + self.assertSequenceEqual(D.objects.exclude(b__a__flag=None), [d2]) + self.assertSequenceEqual(D.objects.filter(b__a__flag=None), [d1, d3, d4]) + self.assertSequenceEqual(B.objects.filter(a__flag=None), [b1, b3]) + self.assertSequenceEqual(B.objects.exclude(a__flag=None), [b2]) + + def test_extra_join_condition(self): + # A crude check that content_type_id is taken in account in the + # join/subquery condition. 
+ self.assertIn("content_type_id", str(B.objects.exclude(a__flag=None).query).lower()) + # No need for any joins - the join from inner query can be trimmed in + # this case (but not in the above case as no a objects at all for given + # B would then fail). + self.assertNotIn(" join ", str(B.objects.exclude(a__flag=True).query).lower()) + self.assertIn("content_type_id", str(B.objects.exclude(a__flag=True).query).lower()) + + def test_annotate(self): + hs1 = HasLinkThing.objects.create() + hs2 = HasLinkThing.objects.create() + HasLinkThing.objects.create() + b = Board.objects.create(name=str(hs1.pk)) + Link.objects.create(content_object=hs2) + link = Link.objects.create(content_object=hs1) + Link.objects.create(content_object=b) + qs = HasLinkThing.objects.annotate(Sum('links')).filter(pk=hs1.pk) + # If content_type restriction isn't in the query's join condition, + # then wrong results are produced here as the link to b will also match + # (b and hs1 have equal pks). + self.assertEqual(qs.count(), 1) + self.assertEqual(qs[0].links__sum, link.id) + link.delete() + # Now if we don't have proper left join, we will not produce any + # results at all here. + # clear cached results + qs = qs.all() + self.assertEqual(qs.count(), 1) + # Note - 0 here would be a nicer result... + self.assertIs(qs[0].links__sum, None) + # Finally test that filtering works. 
+ self.assertEqual(qs.filter(links__sum__isnull=True).count(), 1) + self.assertEqual(qs.filter(links__sum__isnull=False).count(), 0) + + def test_filter_targets_related_pk(self): + HasLinkThing.objects.create() + hs2 = HasLinkThing.objects.create() + link = Link.objects.create(content_object=hs2) + self.assertNotEqual(link.object_id, link.pk) + self.assertSequenceEqual(HasLinkThing.objects.filter(links=link.pk), [hs2]) + + def test_editable_generic_rel(self): + GenericRelationForm = modelform_factory(HasLinkThing, fields='__all__') + form = GenericRelationForm() + self.assertIn('links', form.fields) + form = GenericRelationForm({'links': None}) + self.assertTrue(form.is_valid()) + form.save() + links = HasLinkThing._meta.get_field('links') + self.assertEqual(links.save_form_data_calls, 1) + + def test_ticket_22998(self): + related = Related.objects.create() + content = Content.objects.create(related_obj=related) + Node.objects.create(content=content) + + # deleting the Related cascades to the Content cascades to the Node, + # where the pre_delete signal should fire and prevent deletion. + with self.assertRaises(ProtectedError): + related.delete() + + def test_ticket_22982(self): + place = Place.objects.create(name='My Place') + self.assertIn('GenericRelatedObjectManager', str(place.links)) diff --git a/tests/indexes/__init__.py b/tests/indexes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/indexes/models.py b/tests/indexes/models.py new file mode 100644 index 00000000..208da32c --- /dev/null +++ b/tests/indexes/models.py @@ -0,0 +1,54 @@ +from django.db import connection, models + + +class CurrentTranslation(models.ForeignObject): + """ + Creates virtual relation to the translation with model cache enabled. + """ + # Avoid validation + requires_unique_target = False + + def __init__(self, to, on_delete, from_fields, to_fields, **kwargs): + # Disable reverse relation + kwargs['related_name'] = '+' + # Set unique to enable model cache. 
+ kwargs['unique'] = True + super().__init__(to, on_delete, from_fields, to_fields, **kwargs) + + +class ArticleTranslation(models.Model): + + article = models.ForeignKey('indexes.Article', models.CASCADE) + article_no_constraint = models.ForeignKey('indexes.Article', models.CASCADE, db_constraint=False, related_name='+') + language = models.CharField(max_length=10, unique=True) + content = models.TextField() + + +class Article(models.Model): + headline = models.CharField(max_length=100) + pub_date = models.DateTimeField() + + # Add virtual relation to the ArticleTranslation model. + translation = CurrentTranslation(ArticleTranslation, models.CASCADE, ['id'], ['article']) + + class Meta: + index_together = [ + ["headline", "pub_date"], + ] + + +# Model for index_together being used only with single list +class IndexTogetherSingleList(models.Model): + headline = models.CharField(max_length=100) + pub_date = models.DateTimeField() + + class Meta: + index_together = ["headline", "pub_date"] + + +# Indexing a TextField on Oracle or MySQL results in index creation error. +if connection.vendor == 'postgresql': + class IndexedArticle(models.Model): + headline = models.CharField(max_length=100, db_index=True) + body = models.TextField(db_index=True) + slug = models.CharField(max_length=40, unique=True) diff --git a/tests/indexes/tests.py b/tests/indexes/tests.py new file mode 100644 index 00000000..ee2cbd15 --- /dev/null +++ b/tests/indexes/tests.py @@ -0,0 +1,125 @@ +from unittest import skipUnless + +from django.db import connection +from django.db.models.deletion import CASCADE +from django.db.models.fields.related import ForeignKey +from django.test import TestCase, TransactionTestCase + +from .models import Article, ArticleTranslation, IndexTogetherSingleList + + +class SchemaIndexesTests(TestCase): + """ + Test index handling by the db.backends.schema infrastructure. + """ + + def test_index_name_hash(self): + """ + Index names should be deterministic. 
+ """ + with connection.schema_editor() as editor: + index_name = editor._create_index_name( + table_name=Article._meta.db_table, + column_names=("c1",), + suffix="123", + ) + self.assertEqual(index_name, "indexes_article_c1_a52bd80b123") + + def test_index_name(self): + """ + Index names on the built-in database backends:: + * Are truncated as needed. + * Include all the column names. + * Include a deterministic hash. + """ + long_name = 'l%sng' % ('o' * 100) + with connection.schema_editor() as editor: + index_name = editor._create_index_name( + table_name=Article._meta.db_table, + column_names=('c1', 'c2', long_name), + suffix='ix', + ) + expected = { + 'mysql': 'indexes_article_c1_c2_looooooooooooooooooo_255179b2ix', + 'oracle': 'indexes_a_c1_c2_loo_255179b2ix', + 'postgresql': 'indexes_article_c1_c2_loooooooooooooooooo_255179b2ix', + 'sqlite': 'indexes_article_c1_c2_l%sng_255179b2ix' % ('o' * 100), + } + if connection.vendor not in expected: + self.skipTest('This test is only supported on the built-in database backends.') + self.assertEqual(index_name, expected[connection.vendor]) + + def test_index_together(self): + editor = connection.schema_editor() + index_sql = [str(statement) for statement in editor._model_indexes_sql(Article)] + self.assertEqual(len(index_sql), 1) + # Ensure the index name is properly quoted + self.assertIn( + connection.ops.quote_name( + editor._create_index_name(Article._meta.db_table, ['headline', 'pub_date'], suffix='_idx') + ), + index_sql[0] + ) + + def test_index_together_single_list(self): + # Test for using index_together with a single list (#22172) + index_sql = connection.schema_editor()._model_indexes_sql(IndexTogetherSingleList) + self.assertEqual(len(index_sql), 1) + + @skipUnless(connection.vendor == 'postgresql', "This is a postgresql-specific issue") + def test_postgresql_text_indexes(self): + """Test creation of PostgreSQL-specific text indexes (#12234)""" + from .models import IndexedArticle + index_sql = 
[str(statement) for statement in connection.schema_editor()._model_indexes_sql(IndexedArticle)] + self.assertEqual(len(index_sql), 5) + self.assertIn('("headline" varchar_pattern_ops)', index_sql[1]) + self.assertIn('("body" text_pattern_ops)', index_sql[3]) + # unique=True and db_index=True should only create the varchar-specific + # index (#19441). + self.assertIn('("slug" varchar_pattern_ops)', index_sql[4]) + + @skipUnless(connection.vendor == 'postgresql', "This is a postgresql-specific issue") + def test_postgresql_virtual_relation_indexes(self): + """Test indexes are not created for related objects""" + index_sql = connection.schema_editor()._model_indexes_sql(Article) + self.assertEqual(len(index_sql), 1) + + +@skipUnless(connection.vendor == 'mysql', 'MySQL tests') +class SchemaIndexesMySQLTests(TransactionTestCase): + available_apps = ['indexes'] + + def test_no_index_for_foreignkey(self): + """ + MySQL on InnoDB already creates indexes automatically for foreign keys. + (#14180). An index should be created if db_constraint=False (#26171). + """ + storage = connection.introspection.get_storage_engine( + connection.cursor(), ArticleTranslation._meta.db_table + ) + if storage != "InnoDB": + self.skip("This test only applies to the InnoDB storage engine") + index_sql = [str(statement) for statement in connection.schema_editor()._model_indexes_sql(ArticleTranslation)] + self.assertEqual(index_sql, [ + 'CREATE INDEX `indexes_articletranslation_article_no_constraint_id_d6c0806b` ' + 'ON `indexes_articletranslation` (`article_no_constraint_id`)' + ]) + + # The index also shouldn't be created if the ForeignKey is added after + # the model was created. 
+ field_created = False + try: + with connection.schema_editor() as editor: + new_field = ForeignKey(Article, CASCADE) + new_field.set_attributes_from_name('new_foreign_key') + editor.add_field(ArticleTranslation, new_field) + field_created = True + self.assertEqual([str(statement) for statement in editor.deferred_sql], [ + 'ALTER TABLE `indexes_articletranslation` ' + 'ADD CONSTRAINT `indexes_articletrans_new_foreign_key_id_d27a9146_fk_indexes_a` ' + 'FOREIGN KEY (`new_foreign_key_id`) REFERENCES `indexes_article` (`id`)' + ]) + finally: + if field_created: + with connection.schema_editor() as editor: + editor.remove_field(ArticleTranslation, new_field) diff --git a/tests/m2m_and_m2o/__init__.py b/tests/m2m_and_m2o/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/m2m_and_m2o/models.py b/tests/m2m_and_m2o/models.py new file mode 100644 index 00000000..60f5c437 --- /dev/null +++ b/tests/m2m_and_m2o/models.py @@ -0,0 +1,31 @@ +""" +Many-to-many and many-to-one relationships to the same table + +Make sure to set ``related_name`` if you use relationships to the same table. 
+""" +from __future__ import unicode_literals + +from django.db import models +from django.utils import six +from django.utils.encoding import python_2_unicode_compatible + + +class User(models.Model): + username = models.CharField(max_length=20) + + +@python_2_unicode_compatible +class Issue(models.Model): + num = models.IntegerField() + cc = models.ManyToManyField(User, blank=True, related_name='test_issue_cc') + client = models.ForeignKey(User, models.CASCADE, related_name='test_issue_client') + + def __str__(self): + return six.text_type(self.num) + + class Meta: + ordering = ('num',) + + +class UnicodeReferenceModel(models.Model): + others = models.ManyToManyField("UnicodeReferenceModel") diff --git a/tests/m2m_and_m2o/tests.py b/tests/m2m_and_m2o/tests.py new file mode 100644 index 00000000..2c84a7f2 --- /dev/null +++ b/tests/m2m_and_m2o/tests.py @@ -0,0 +1,94 @@ +from django.db.models import Q +from django.test import TestCase + +from .models import Issue, UnicodeReferenceModel, User + + +class RelatedObjectTests(TestCase): + + def test_related_objects_have_name_attribute(self): + for field_name in ('test_issue_client', 'test_issue_cc'): + obj = User._meta.get_field(field_name) + self.assertEqual(field_name, obj.field.related_query_name()) + + def test_m2m_and_m2o(self): + r = User.objects.create(username="russell") + g = User.objects.create(username="gustav") + + i1 = Issue(num=1) + i1.client = r + i1.save() + + i2 = Issue(num=2) + i2.client = r + i2.save() + i2.cc.add(r) + + i3 = Issue(num=3) + i3.client = g + i3.save() + i3.cc.add(r) + + self.assertQuerysetEqual( + Issue.objects.filter(client=r.id), [ + 1, + 2, + ], + lambda i: i.num + ) + self.assertQuerysetEqual( + Issue.objects.filter(client=g.id), [ + 3, + ], + lambda i: i.num + ) + self.assertQuerysetEqual( + Issue.objects.filter(cc__id__exact=g.id), [] + ) + self.assertQuerysetEqual( + Issue.objects.filter(cc__id__exact=r.id), [ + 2, + 3, + ], + lambda i: i.num + ) + + # These queries combine 
results from the m2m and the m2o relationships. + # They're three ways of saying the same thing. + self.assertQuerysetEqual( + Issue.objects.filter(Q(cc__id__exact=r.id) | Q(client=r.id)), [ + 1, + 2, + 3, + ], + lambda i: i.num + ) + self.assertQuerysetEqual( + Issue.objects.filter(cc__id__exact=r.id) | Issue.objects.filter(client=r.id), [ + 1, + 2, + 3, + ], + lambda i: i.num + ) + self.assertQuerysetEqual( + Issue.objects.filter(Q(client=r.id) | Q(cc__id__exact=r.id)), [ + 1, + 2, + 3, + ], + lambda i: i.num + ) + + +class RelatedObjectUnicodeTests(TestCase): + def test_m2m_with_unicode_reference(self): + """ + Regression test for #6045: references to other models can be unicode + strings, providing they are directly convertible to ASCII. + """ + m1 = UnicodeReferenceModel.objects.create() + m2 = UnicodeReferenceModel.objects.create() + m2.others.add(m1) # used to cause an error (see ticket #6045) + m2.save() + list(m2.others.all()) # Force retrieval. diff --git a/tests/m2m_intermediary/__init__.py b/tests/m2m_intermediary/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/m2m_intermediary/models.py b/tests/m2m_intermediary/models.py new file mode 100644 index 00000000..3e73164e --- /dev/null +++ b/tests/m2m_intermediary/models.py @@ -0,0 +1,43 @@ +""" +Many-to-many relationships via an intermediary table + +For many-to-many relationships that need extra fields on the intermediary +table, use an intermediary model. + +In this example, an ``Article`` can have multiple ``Reporter`` objects, and +each ``Article``-``Reporter`` combination (a ``Writer``) has a ``position`` +field, which specifies the ``Reporter``'s position for the given article +(e.g. "Staff writer"). 
+""" +from __future__ import unicode_literals + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Reporter(models.Model): + first_name = models.CharField(max_length=30) + last_name = models.CharField(max_length=30) + + def __str__(self): + return "%s %s" % (self.first_name, self.last_name) + + +@python_2_unicode_compatible +class Article(models.Model): + headline = models.CharField(max_length=100) + pub_date = models.DateField() + + def __str__(self): + return self.headline + + +@python_2_unicode_compatible +class Writer(models.Model): + reporter = models.ForeignKey(Reporter, models.CASCADE) + article = models.ForeignKey(Article, models.CASCADE) + position = models.CharField(max_length=100) + + def __str__(self): + return '%s (%s)' % (self.reporter, self.position) diff --git a/tests/m2m_intermediary/tests.py b/tests/m2m_intermediary/tests.py new file mode 100644 index 00000000..ce4e1860 --- /dev/null +++ b/tests/m2m_intermediary/tests.py @@ -0,0 +1,41 @@ +from __future__ import unicode_literals + +from datetime import datetime + +from django.test import TestCase +from django.utils import six + +from .models import Article, Reporter, Writer + + +class M2MIntermediaryTests(TestCase): + def test_intermeiary(self): + r1 = Reporter.objects.create(first_name="John", last_name="Smith") + r2 = Reporter.objects.create(first_name="Jane", last_name="Doe") + + a = Article.objects.create( + headline="This is a test", pub_date=datetime(2005, 7, 27) + ) + + w1 = Writer.objects.create(reporter=r1, article=a, position="Main writer") + w2 = Writer.objects.create(reporter=r2, article=a, position="Contributor") + + self.assertQuerysetEqual( + a.writer_set.select_related().order_by("-position"), [ + ("John Smith", "Main writer"), + ("Jane Doe", "Contributor"), + ], + lambda w: (six.text_type(w.reporter), w.position) + ) + self.assertEqual(w1.reporter, r1) + self.assertEqual(w2.reporter, r2) + + 
self.assertEqual(w1.article, a) + self.assertEqual(w2.article, a) + + self.assertQuerysetEqual( + r1.writer_set.all(), [ + ("John Smith", "Main writer") + ], + lambda w: (six.text_type(w.reporter), w.position) + ) diff --git a/tests/m2m_multiple/__init__.py b/tests/m2m_multiple/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/m2m_multiple/models.py b/tests/m2m_multiple/models.py new file mode 100644 index 00000000..a6db9425 --- /dev/null +++ b/tests/m2m_multiple/models.py @@ -0,0 +1,36 @@ +""" +Multiple many-to-many relationships between the same two tables + +In this example, an ``Article`` can have many "primary" ``Category`` objects +and many "secondary" ``Category`` objects. + +Set ``related_name`` to designate what the reverse relationship is called. +""" + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Category(models.Model): + name = models.CharField(max_length=20) + + class Meta: + ordering = ('name',) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Article(models.Model): + headline = models.CharField(max_length=50) + pub_date = models.DateTimeField() + primary_categories = models.ManyToManyField(Category, related_name='primary_article_set') + secondary_categories = models.ManyToManyField(Category, related_name='secondary_article_set') + + class Meta: + ordering = ('pub_date',) + + def __str__(self): + return self.headline diff --git a/tests/m2m_multiple/tests.py b/tests/m2m_multiple/tests.py new file mode 100644 index 00000000..9d605423 --- /dev/null +++ b/tests/m2m_multiple/tests.py @@ -0,0 +1,86 @@ +from __future__ import unicode_literals + +from datetime import datetime + +from django.test import TestCase + +from .models import Article, Category + + +class M2MMultipleTests(TestCase): + def test_multiple(self): + c1, c2, c3, c4 = [ + Category.objects.create(name=name) + for name in ["Sports", "News", 
"Crime", "Life"] + ] + + a1 = Article.objects.create( + headline="Parrot steals", pub_date=datetime(2005, 11, 27) + ) + a1.primary_categories.add(c2, c3) + a1.secondary_categories.add(c4) + + a2 = Article.objects.create( + headline="Parrot runs", pub_date=datetime(2005, 11, 28) + ) + a2.primary_categories.add(c1, c2) + a2.secondary_categories.add(c4) + + self.assertQuerysetEqual( + a1.primary_categories.all(), [ + "Crime", + "News", + ], + lambda c: c.name + ) + self.assertQuerysetEqual( + a2.primary_categories.all(), [ + "News", + "Sports", + ], + lambda c: c.name + ) + self.assertQuerysetEqual( + a1.secondary_categories.all(), [ + "Life", + ], + lambda c: c.name + ) + self.assertQuerysetEqual( + c1.primary_article_set.all(), [ + "Parrot runs", + ], + lambda a: a.headline + ) + self.assertQuerysetEqual( + c1.secondary_article_set.all(), [] + ) + self.assertQuerysetEqual( + c2.primary_article_set.all(), [ + "Parrot steals", + "Parrot runs", + ], + lambda a: a.headline + ) + self.assertQuerysetEqual( + c2.secondary_article_set.all(), [] + ) + self.assertQuerysetEqual( + c3.primary_article_set.all(), [ + "Parrot steals", + ], + lambda a: a.headline + ) + self.assertQuerysetEqual( + c3.secondary_article_set.all(), [] + ) + self.assertQuerysetEqual( + c4.primary_article_set.all(), [] + ) + self.assertQuerysetEqual( + c4.secondary_article_set.all(), [ + "Parrot steals", + "Parrot runs", + ], + lambda a: a.headline + ) diff --git a/tests/m2m_recursive/__init__.py b/tests/m2m_recursive/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/m2m_recursive/models.py b/tests/m2m_recursive/models.py new file mode 100644 index 00000000..d224b3d5 --- /dev/null +++ b/tests/m2m_recursive/models.py @@ -0,0 +1,30 @@ +""" +Many-to-many relationships between the same two tables + +In this example, a ``Person`` can have many friends, who are also ``Person`` +objects. Friendship is a symmetrical relationship - if I am your friend, you +are my friend. 
Here, ``friends`` is an example of a symmetrical +``ManyToManyField``. + +A ``Person`` can also have many idols - but while I may idolize you, you may +not think the same of me. Here, ``idols`` is an example of a non-symmetrical +``ManyToManyField``. Only recursive ``ManyToManyField`` fields may be +non-symmetrical, and they are symmetrical by default. + +This test validates that the many-to-many table is created using a mangled name +if there is a name clash, and tests that symmetry is preserved where +appropriate. +""" + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Person(models.Model): + name = models.CharField(max_length=20) + friends = models.ManyToManyField('self') + idols = models.ManyToManyField('self', symmetrical=False, related_name='stalkers') + + def __str__(self): + return self.name diff --git a/tests/m2m_recursive/tests.py b/tests/m2m_recursive/tests.py new file mode 100644 index 00000000..c6573800 --- /dev/null +++ b/tests/m2m_recursive/tests.py @@ -0,0 +1,183 @@ +from __future__ import unicode_literals + +from operator import attrgetter + +from django.test import TestCase + +from .models import Person + + +class RecursiveM2MTests(TestCase): + def setUp(self): + self.a, self.b, self.c, self.d = [ + Person.objects.create(name=name) + for name in ["Anne", "Bill", "Chuck", "David"] + ] + + # Anne is friends with Bill and Chuck + self.a.friends.add(self.b, self.c) + + # David is friends with Anne and Chuck - add in reverse direction + self.d.friends.add(self.a, self.c) + + def test_recursive_m2m_all(self): + # Who is friends with Anne? + self.assertQuerysetEqual( + self.a.friends.all(), [ + "Bill", + "Chuck", + "David" + ], + attrgetter("name"), + ordered=False + ) + # Who is friends with Bill? + self.assertQuerysetEqual( + self.b.friends.all(), [ + "Anne", + ], + attrgetter("name") + ) + # Who is friends with Chuck? 
+ self.assertQuerysetEqual( + self.c.friends.all(), [ + "Anne", + "David" + ], + attrgetter("name"), + ordered=False + ) + # Who is friends with David? + self.assertQuerysetEqual( + self.d.friends.all(), [ + "Anne", + "Chuck", + ], + attrgetter("name"), + ordered=False + ) + + def test_recursive_m2m_reverse_add(self): + # Bill is already friends with Anne - add Anne again, but in the + # reverse direction + self.b.friends.add(self.a) + + # Who is friends with Anne? + self.assertQuerysetEqual( + self.a.friends.all(), [ + "Bill", + "Chuck", + "David", + ], + attrgetter("name"), + ordered=False + ) + # Who is friends with Bill? + self.assertQuerysetEqual( + self.b.friends.all(), [ + "Anne", + ], + attrgetter("name") + ) + + def test_recursive_m2m_remove(self): + # Remove Anne from Bill's friends + self.b.friends.remove(self.a) + + # Who is friends with Anne? + self.assertQuerysetEqual( + self.a.friends.all(), [ + "Chuck", + "David", + ], + attrgetter("name"), + ordered=False + ) + # Who is friends with Bill? + self.assertQuerysetEqual( + self.b.friends.all(), [] + ) + + def test_recursive_m2m_clear(self): + # Clear Anne's group of friends + self.a.friends.clear() + + # Who is friends with Anne? + self.assertQuerysetEqual( + self.a.friends.all(), [] + ) + + # Reverse relationships should also be gone + # Who is friends with Chuck? + self.assertQuerysetEqual( + self.c.friends.all(), [ + "David", + ], + attrgetter("name") + ) + + # Who is friends with David? + self.assertQuerysetEqual( + self.d.friends.all(), [ + "Chuck", + ], + attrgetter("name") + ) + + def test_recursive_m2m_add_via_related_name(self): + # David is idolized by Anne and Chuck - add in reverse direction + self.d.stalkers.add(self.a) + + # Who are Anne's idols? + self.assertQuerysetEqual( + self.a.idols.all(), [ + "David", + ], + attrgetter("name"), + ordered=False + ) + # Who is stalking Anne? 
+ self.assertQuerysetEqual( + self.a.stalkers.all(), [], + attrgetter("name") + ) + + def test_recursive_m2m_add_in_both_directions(self): + """Adding the same relation twice results in a single relation.""" + # Ann idolizes David + self.a.idols.add(self.d) + + # David is idolized by Anne + self.d.stalkers.add(self.a) + + # Who are Anne's idols? + self.assertQuerysetEqual( + self.a.idols.all(), [ + "David", + ], + attrgetter("name"), + ordered=False + ) + # As the assertQuerysetEqual uses a set for comparison, + # check we've only got David listed once + self.assertEqual(self.a.idols.all().count(), 1) + + def test_recursive_m2m_related_to_self(self): + # Ann idolizes herself + self.a.idols.add(self.a) + + # Who are Anne's idols? + self.assertQuerysetEqual( + self.a.idols.all(), [ + "Anne", + ], + attrgetter("name"), + ordered=False + ) + # Who is stalking Anne? + self.assertQuerysetEqual( + self.a.stalkers.all(), [ + "Anne", + ], + attrgetter("name") + ) diff --git a/tests/m2m_regress/__init__.py b/tests/m2m_regress/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/m2m_regress/models.py b/tests/m2m_regress/models.py new file mode 100644 index 00000000..57f02b8f --- /dev/null +++ b/tests/m2m_regress/models.py @@ -0,0 +1,100 @@ +from django.contrib.auth import models as auth +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +# No related name is needed here, since symmetrical relations are not +# explicitly reversible. 
+@python_2_unicode_compatible +class SelfRefer(models.Model): + name = models.CharField(max_length=10) + references = models.ManyToManyField('self') + related = models.ManyToManyField('self') + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Tag(models.Model): + name = models.CharField(max_length=10) + + def __str__(self): + return self.name + + +# Regression for #11956 -- a many to many to the base class +@python_2_unicode_compatible +class TagCollection(Tag): + tags = models.ManyToManyField(Tag, related_name='tag_collections') + + def __str__(self): + return self.name + + +# A related_name is required on one of the ManyToManyField entries here because +# they are both addressable as reverse relations from Tag. +@python_2_unicode_compatible +class Entry(models.Model): + name = models.CharField(max_length=10) + topics = models.ManyToManyField(Tag) + related = models.ManyToManyField(Tag, related_name="similar") + + def __str__(self): + return self.name + + +# Two models both inheriting from a base model with a self-referential m2m field +class SelfReferChild(SelfRefer): + pass + + +class SelfReferChildSibling(SelfRefer): + pass + + +# Many-to-Many relation between models, where one of the PK's isn't an Autofield +@python_2_unicode_compatible +class Line(models.Model): + name = models.CharField(max_length=100) + + def __str__(self): + return self.name + + +class Worksheet(models.Model): + id = models.CharField(primary_key=True, max_length=100) + lines = models.ManyToManyField(Line, blank=True) + + +# Regression for #11226 -- A model with the same name that another one to +# which it has a m2m relation. 
This shouldn't cause a name clash between +# the automatically created m2m intermediary table FK field names when +# running migrate +class User(models.Model): + name = models.CharField(max_length=30) + friends = models.ManyToManyField(auth.User) + + +class BadModelWithSplit(models.Model): + name = models.CharField(max_length=1) + + def split(self): + raise RuntimeError('split should not be called') + + class Meta: + abstract = True + + +class RegressionModelSplit(BadModelWithSplit): + """ + Model with a split method should not cause an error in add_lazy_relation + """ + others = models.ManyToManyField('self') + + +# Regression for #24505 -- Two ManyToManyFields with the same "to" model +# and related_name set to '+'. +class Post(models.Model): + primary_lines = models.ManyToManyField(Line, related_name='+') + secondary_lines = models.ManyToManyField(Line, related_name='+') diff --git a/tests/m2m_regress/tests.py b/tests/m2m_regress/tests.py new file mode 100644 index 00000000..3c882c59 --- /dev/null +++ b/tests/m2m_regress/tests.py @@ -0,0 +1,122 @@ +from __future__ import unicode_literals + +from django.core.exceptions import FieldError +from django.test import TestCase + +from .models import ( + Entry, Line, Post, RegressionModelSplit, SelfRefer, SelfReferChild, + SelfReferChildSibling, Tag, TagCollection, Worksheet, +) + + +class M2MRegressionTests(TestCase): + def test_multiple_m2m(self): + # Multiple m2m references to model must be distinguished when + # accessing the relations through an instance attribute. 
+ + s1 = SelfRefer.objects.create(name='s1') + s2 = SelfRefer.objects.create(name='s2') + s3 = SelfRefer.objects.create(name='s3') + s1.references.add(s2) + s1.related.add(s3) + + e1 = Entry.objects.create(name='e1') + t1 = Tag.objects.create(name='t1') + t2 = Tag.objects.create(name='t2') + + e1.topics.add(t1) + e1.related.add(t2) + + self.assertQuerysetEqual(s1.references.all(), [""]) + self.assertQuerysetEqual(s1.related.all(), [""]) + + self.assertQuerysetEqual(e1.topics.all(), [""]) + self.assertQuerysetEqual(e1.related.all(), [""]) + + def test_internal_related_name_not_in_error_msg(self): + # The secret internal related names for self-referential many-to-many + # fields shouldn't appear in the list when an error is made. + self.assertRaisesMessage( + FieldError, + "Choices are: id, name, references, related, selfreferchild, selfreferchildsibling", + lambda: SelfRefer.objects.filter(porcupine='fred') + ) + + def test_m2m_inheritance_symmetry(self): + # Test to ensure that the relationship between two inherited models + # with a self-referential m2m field maintains symmetry + + sr_child = SelfReferChild(name="Hanna") + sr_child.save() + + sr_sibling = SelfReferChildSibling(name="Beth") + sr_sibling.save() + sr_child.related.add(sr_sibling) + + self.assertQuerysetEqual(sr_child.related.all(), [""]) + self.assertQuerysetEqual(sr_sibling.related.all(), [""]) + + def test_m2m_pk_field_type(self): + # Regression for #11311 - The primary key for models in a m2m relation + # doesn't have to be an AutoField + + w = Worksheet(id='abc') + w.save() + w.delete() + + def test_add_m2m_with_base_class(self): + # Regression for #11956 -- You can add an object to a m2m with the + # base class without causing integrity errors + + t1 = Tag.objects.create(name='t1') + t2 = Tag.objects.create(name='t2') + + c1 = TagCollection.objects.create(name='c1') + c1.tags.set([t1, t2]) + c1 = TagCollection.objects.get(name='c1') + + self.assertQuerysetEqual(c1.tags.all(), ["", ""], 
ordered=False) + self.assertQuerysetEqual(t1.tag_collections.all(), [""]) + + def test_manager_class_caching(self): + e1 = Entry.objects.create() + e2 = Entry.objects.create() + t1 = Tag.objects.create() + t2 = Tag.objects.create() + + # Get same manager twice in a row: + self.assertIs(t1.entry_set.__class__, t1.entry_set.__class__) + self.assertIs(e1.topics.__class__, e1.topics.__class__) + + # Get same manager for different instances + self.assertIs(e1.topics.__class__, e2.topics.__class__) + self.assertIs(t1.entry_set.__class__, t2.entry_set.__class__) + + def test_m2m_abstract_split(self): + # Regression for #19236 - an abstract class with a 'split' method + # causes a TypeError in add_lazy_relation + m1 = RegressionModelSplit(name='1') + m1.save() + + def test_assigning_invalid_data_to_m2m_doesnt_clear_existing_relations(self): + t1 = Tag.objects.create(name='t1') + t2 = Tag.objects.create(name='t2') + c1 = TagCollection.objects.create(name='c1') + c1.tags.set([t1, t2]) + + with self.assertRaises(TypeError): + c1.tags.set(7) + + c1.refresh_from_db() + self.assertQuerysetEqual(c1.tags.order_by('name'), ["", ""]) + + def test_multiple_forwards_only_m2m(self): + # Regression for #24505 - Multiple ManyToManyFields to same "to" + # model with related_name set to '+'. 
+ foo = Line.objects.create(name='foo') + bar = Line.objects.create(name='bar') + post = Post.objects.create() + post.primary_lines.add(foo) + post.secondary_lines.add(bar) + self.assertQuerysetEqual(post.primary_lines.all(), ['']) + self.assertQuerysetEqual(post.secondary_lines.all(), ['']) diff --git a/tests/m2m_signals/__init__.py b/tests/m2m_signals/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/m2m_signals/models.py b/tests/m2m_signals/models.py new file mode 100644 index 00000000..e4110ccf --- /dev/null +++ b/tests/m2m_signals/models.py @@ -0,0 +1,43 @@ +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Part(models.Model): + name = models.CharField(max_length=20) + + class Meta: + ordering = ('name',) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Car(models.Model): + name = models.CharField(max_length=20) + default_parts = models.ManyToManyField(Part) + optional_parts = models.ManyToManyField(Part, related_name='cars_optional') + + class Meta: + ordering = ('name',) + + def __str__(self): + return self.name + + +class SportsCar(Car): + price = models.IntegerField() + + +@python_2_unicode_compatible +class Person(models.Model): + name = models.CharField(max_length=20) + fans = models.ManyToManyField('self', related_name='idols', symmetrical=False) + friends = models.ManyToManyField('self') + + class Meta: + ordering = ('name',) + + def __str__(self): + return self.name diff --git a/tests/m2m_signals/tests.py b/tests/m2m_signals/tests.py new file mode 100644 index 00000000..834897eb --- /dev/null +++ b/tests/m2m_signals/tests.py @@ -0,0 +1,463 @@ +""" +Testing signals emitted on changing m2m relations. 
+""" + +from django.db import models +from django.test import TestCase + +from .models import Car, Part, Person, SportsCar + + +class ManyToManySignalsTest(TestCase): + def m2m_changed_signal_receiver(self, signal, sender, **kwargs): + message = { + 'instance': kwargs['instance'], + 'action': kwargs['action'], + 'reverse': kwargs['reverse'], + 'model': kwargs['model'], + } + if kwargs['pk_set']: + message['objects'] = list( + kwargs['model'].objects.filter(pk__in=kwargs['pk_set']) + ) + self.m2m_changed_messages.append(message) + + def setUp(self): + self.m2m_changed_messages = [] + + self.vw = Car.objects.create(name='VW') + self.bmw = Car.objects.create(name='BMW') + self.toyota = Car.objects.create(name='Toyota') + + self.wheelset = Part.objects.create(name='Wheelset') + self.doors = Part.objects.create(name='Doors') + self.engine = Part.objects.create(name='Engine') + self.airbag = Part.objects.create(name='Airbag') + self.sunroof = Part.objects.create(name='Sunroof') + + self.alice = Person.objects.create(name='Alice') + self.bob = Person.objects.create(name='Bob') + self.chuck = Person.objects.create(name='Chuck') + self.daisy = Person.objects.create(name='Daisy') + + def tearDown(self): + # disconnect all signal handlers + models.signals.m2m_changed.disconnect( + self.m2m_changed_signal_receiver, Car.default_parts.through + ) + models.signals.m2m_changed.disconnect( + self.m2m_changed_signal_receiver, Car.optional_parts.through + ) + models.signals.m2m_changed.disconnect( + self.m2m_changed_signal_receiver, Person.fans.through + ) + models.signals.m2m_changed.disconnect( + self.m2m_changed_signal_receiver, Person.friends.through + ) + + def _initialize_signal_car(self, add_default_parts_before_set_signal=False): + """ Install a listener on the two m2m relations. 
""" + models.signals.m2m_changed.connect( + self.m2m_changed_signal_receiver, Car.optional_parts.through + ) + if add_default_parts_before_set_signal: + # adding a default part to our car - no signal listener installed + self.vw.default_parts.add(self.sunroof) + models.signals.m2m_changed.connect( + self.m2m_changed_signal_receiver, Car.default_parts.through + ) + + def test_m2m_relations_add_remove_clear(self): + expected_messages = [] + + self._initialize_signal_car(add_default_parts_before_set_signal=True) + + self.vw.default_parts.add(self.wheelset, self.doors, self.engine) + expected_messages.append({ + 'instance': self.vw, + 'action': 'pre_add', + 'reverse': False, + 'model': Part, + 'objects': [self.doors, self.engine, self.wheelset], + }) + expected_messages.append({ + 'instance': self.vw, + 'action': 'post_add', + 'reverse': False, + 'model': Part, + 'objects': [self.doors, self.engine, self.wheelset], + }) + self.assertEqual(self.m2m_changed_messages, expected_messages) + + # give the BMW and Toyota some doors as well + self.doors.car_set.add(self.bmw, self.toyota) + expected_messages.append({ + 'instance': self.doors, + 'action': 'pre_add', + 'reverse': True, + 'model': Car, + 'objects': [self.bmw, self.toyota], + }) + expected_messages.append({ + 'instance': self.doors, + 'action': 'post_add', + 'reverse': True, + 'model': Car, + 'objects': [self.bmw, self.toyota], + }) + self.assertEqual(self.m2m_changed_messages, expected_messages) + + def test_m2m_relations_signals_remove_relation(self): + self._initialize_signal_car() + # remove the engine from the self.vw and the airbag (which is not set + # but is returned) + self.vw.default_parts.remove(self.engine, self.airbag) + self.assertEqual(self.m2m_changed_messages, [ + { + 'instance': self.vw, + 'action': 'pre_remove', + 'reverse': False, + 'model': Part, + 'objects': [self.airbag, self.engine], + }, { + 'instance': self.vw, + 'action': 'post_remove', + 'reverse': False, + 'model': Part, + 'objects': 
[self.airbag, self.engine], + } + ]) + + def test_m2m_relations_signals_give_the_self_vw_some_optional_parts(self): + expected_messages = [] + + self._initialize_signal_car() + + # give the self.vw some optional parts (second relation to same model) + self.vw.optional_parts.add(self.airbag, self.sunroof) + expected_messages.append({ + 'instance': self.vw, + 'action': 'pre_add', + 'reverse': False, + 'model': Part, + 'objects': [self.airbag, self.sunroof], + }) + expected_messages.append({ + 'instance': self.vw, + 'action': 'post_add', + 'reverse': False, + 'model': Part, + 'objects': [self.airbag, self.sunroof], + }) + self.assertEqual(self.m2m_changed_messages, expected_messages) + + # add airbag to all the cars (even though the self.vw already has one) + self.airbag.cars_optional.add(self.vw, self.bmw, self.toyota) + expected_messages.append({ + 'instance': self.airbag, + 'action': 'pre_add', + 'reverse': True, + 'model': Car, + 'objects': [self.bmw, self.toyota], + }) + expected_messages.append({ + 'instance': self.airbag, + 'action': 'post_add', + 'reverse': True, + 'model': Car, + 'objects': [self.bmw, self.toyota], + }) + self.assertEqual(self.m2m_changed_messages, expected_messages) + + def test_m2m_relations_signals_reverse_relation_with_custom_related_name(self): + self._initialize_signal_car() + # remove airbag from the self.vw (reverse relation with custom + # related_name) + self.airbag.cars_optional.remove(self.vw) + self.assertEqual(self.m2m_changed_messages, [ + { + 'instance': self.airbag, + 'action': 'pre_remove', + 'reverse': True, + 'model': Car, + 'objects': [self.vw], + }, { + 'instance': self.airbag, + 'action': 'post_remove', + 'reverse': True, + 'model': Car, + 'objects': [self.vw], + } + ]) + + def test_m2m_relations_signals_clear_all_parts_of_the_self_vw(self): + self._initialize_signal_car() + # clear all parts of the self.vw + self.vw.default_parts.clear() + self.assertEqual(self.m2m_changed_messages, [ + { + 'instance': self.vw, + 
'action': 'pre_clear', + 'reverse': False, + 'model': Part, + }, { + 'instance': self.vw, + 'action': 'post_clear', + 'reverse': False, + 'model': Part, + } + ]) + + def test_m2m_relations_signals_all_the_doors_off_of_cars(self): + self._initialize_signal_car() + # take all the doors off of cars + self.doors.car_set.clear() + self.assertEqual(self.m2m_changed_messages, [ + { + 'instance': self.doors, + 'action': 'pre_clear', + 'reverse': True, + 'model': Car, + }, { + 'instance': self.doors, + 'action': 'post_clear', + 'reverse': True, + 'model': Car, + } + ]) + + def test_m2m_relations_signals_reverse_relation(self): + self._initialize_signal_car() + # take all the airbags off of cars (clear reverse relation with custom + # related_name) + self.airbag.cars_optional.clear() + self.assertEqual(self.m2m_changed_messages, [ + { + 'instance': self.airbag, + 'action': 'pre_clear', + 'reverse': True, + 'model': Car, + }, { + 'instance': self.airbag, + 'action': 'post_clear', + 'reverse': True, + 'model': Car, + } + ]) + + def test_m2m_relations_signals_alternative_ways(self): + expected_messages = [] + + self._initialize_signal_car() + + # alternative ways of setting relation: + self.vw.default_parts.create(name='Windows') + p6 = Part.objects.get(name='Windows') + expected_messages.append({ + 'instance': self.vw, + 'action': 'pre_add', + 'reverse': False, + 'model': Part, + 'objects': [p6], + }) + expected_messages.append({ + 'instance': self.vw, + 'action': 'post_add', + 'reverse': False, + 'model': Part, + 'objects': [p6], + }) + self.assertEqual(self.m2m_changed_messages, expected_messages) + + # direct assignment clears the set first, then adds + self.vw.default_parts.set([self.wheelset, self.doors, self.engine]) + expected_messages.append({ + 'instance': self.vw, + 'action': 'pre_remove', + 'reverse': False, + 'model': Part, + 'objects': [p6], + }) + expected_messages.append({ + 'instance': self.vw, + 'action': 'post_remove', + 'reverse': False, + 'model': Part, + 
'objects': [p6], + }) + expected_messages.append({ + 'instance': self.vw, + 'action': 'pre_add', + 'reverse': False, + 'model': Part, + 'objects': [self.doors, self.engine, self.wheelset], + }) + expected_messages.append({ + 'instance': self.vw, + 'action': 'post_add', + 'reverse': False, + 'model': Part, + 'objects': [self.doors, self.engine, self.wheelset], + }) + self.assertEqual(self.m2m_changed_messages, expected_messages) + + def test_m2m_relations_signals_clearing_removing(self): + expected_messages = [] + + self._initialize_signal_car(add_default_parts_before_set_signal=True) + + # set by clearing. + self.vw.default_parts.set([self.wheelset, self.doors, self.engine], clear=True) + expected_messages.append({ + 'instance': self.vw, + 'action': 'pre_clear', + 'reverse': False, + 'model': Part, + }) + expected_messages.append({ + 'instance': self.vw, + 'action': 'post_clear', + 'reverse': False, + 'model': Part, + }) + expected_messages.append({ + 'instance': self.vw, + 'action': 'pre_add', + 'reverse': False, + 'model': Part, + 'objects': [self.doors, self.engine, self.wheelset], + }) + expected_messages.append({ + 'instance': self.vw, + 'action': 'post_add', + 'reverse': False, + 'model': Part, + 'objects': [self.doors, self.engine, self.wheelset], + }) + self.assertEqual(self.m2m_changed_messages, expected_messages) + + # set by only removing what's necessary. 
+ self.vw.default_parts.set([self.wheelset, self.doors], clear=False) + expected_messages.append({ + 'instance': self.vw, + 'action': 'pre_remove', + 'reverse': False, + 'model': Part, + 'objects': [self.engine], + }) + expected_messages.append({ + 'instance': self.vw, + 'action': 'post_remove', + 'reverse': False, + 'model': Part, + 'objects': [self.engine], + }) + self.assertEqual(self.m2m_changed_messages, expected_messages) + + def test_m2m_relations_signals_when_inheritance(self): + expected_messages = [] + + self._initialize_signal_car(add_default_parts_before_set_signal=True) + + # Signals still work when model inheritance is involved + c4 = SportsCar.objects.create(name='Bugatti', price='1000000') + c4b = Car.objects.get(name='Bugatti') + c4.default_parts.set([self.doors]) + expected_messages.append({ + 'instance': c4, + 'action': 'pre_add', + 'reverse': False, + 'model': Part, + 'objects': [self.doors], + }) + expected_messages.append({ + 'instance': c4, + 'action': 'post_add', + 'reverse': False, + 'model': Part, + 'objects': [self.doors], + }) + self.assertEqual(self.m2m_changed_messages, expected_messages) + + self.engine.car_set.add(c4) + expected_messages.append({ + 'instance': self.engine, + 'action': 'pre_add', + 'reverse': True, + 'model': Car, + 'objects': [c4b], + }) + expected_messages.append({ + 'instance': self.engine, + 'action': 'post_add', + 'reverse': True, + 'model': Car, + 'objects': [c4b], + }) + self.assertEqual(self.m2m_changed_messages, expected_messages) + + def _initialize_signal_person(self): + # Install a listener on the two m2m relations. 
+ models.signals.m2m_changed.connect( + self.m2m_changed_signal_receiver, Person.fans.through + ) + models.signals.m2m_changed.connect( + self.m2m_changed_signal_receiver, Person.friends.through + ) + + def test_m2m_relations_with_self_add_friends(self): + self._initialize_signal_person() + self.alice.friends.set([self.bob, self.chuck]) + self.assertEqual(self.m2m_changed_messages, [ + { + 'instance': self.alice, + 'action': 'pre_add', + 'reverse': False, + 'model': Person, + 'objects': [self.bob, self.chuck], + }, { + 'instance': self.alice, + 'action': 'post_add', + 'reverse': False, + 'model': Person, + 'objects': [self.bob, self.chuck], + } + ]) + + def test_m2m_relations_with_self_add_fan(self): + self._initialize_signal_person() + self.alice.fans.set([self.daisy]) + self.assertEqual(self.m2m_changed_messages, [ + { + 'instance': self.alice, + 'action': 'pre_add', + 'reverse': False, + 'model': Person, + 'objects': [self.daisy], + }, { + 'instance': self.alice, + 'action': 'post_add', + 'reverse': False, + 'model': Person, + 'objects': [self.daisy], + } + ]) + + def test_m2m_relations_with_self_add_idols(self): + self._initialize_signal_person() + self.chuck.idols.set([self.alice, self.bob]) + self.assertEqual(self.m2m_changed_messages, [ + { + 'instance': self.chuck, + 'action': 'pre_add', + 'reverse': True, + 'model': Person, + 'objects': [self.alice, self.bob], + }, { + 'instance': self.chuck, + 'action': 'post_add', + 'reverse': True, + 'model': Person, + 'objects': [self.alice, self.bob], + } + ]) diff --git a/tests/m2m_through/__init__.py b/tests/m2m_through/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/m2m_through/models.py b/tests/m2m_through/models.py new file mode 100644 index 00000000..dab3be51 --- /dev/null +++ b/tests/m2m_through/models.py @@ -0,0 +1,156 @@ +from datetime import datetime + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +# M2M described on one of the 
models +@python_2_unicode_compatible +class Person(models.Model): + name = models.CharField(max_length=128) + + class Meta: + ordering = ('name',) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Group(models.Model): + name = models.CharField(max_length=128) + members = models.ManyToManyField(Person, through='Membership') + custom_members = models.ManyToManyField(Person, through='CustomMembership', related_name="custom") + nodefaultsnonulls = models.ManyToManyField( + Person, + through='TestNoDefaultsOrNulls', + related_name="testnodefaultsnonulls", + ) + + class Meta: + ordering = ('name',) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Membership(models.Model): + person = models.ForeignKey(Person, models.CASCADE) + group = models.ForeignKey(Group, models.CASCADE) + date_joined = models.DateTimeField(default=datetime.now) + invite_reason = models.CharField(max_length=64, null=True) + + class Meta: + ordering = ('date_joined', 'invite_reason', 'group') + + def __str__(self): + return "%s is a member of %s" % (self.person.name, self.group.name) + + +@python_2_unicode_compatible +class CustomMembership(models.Model): + person = models.ForeignKey( + Person, + models.CASCADE, + db_column="custom_person_column", + related_name="custom_person_related_name", + ) + group = models.ForeignKey(Group, models.CASCADE) + weird_fk = models.ForeignKey(Membership, models.SET_NULL, null=True) + date_joined = models.DateTimeField(default=datetime.now) + + def __str__(self): + return "%s is a member of %s" % (self.person.name, self.group.name) + + class Meta: + db_table = "test_table" + ordering = ["date_joined"] + + +class TestNoDefaultsOrNulls(models.Model): + person = models.ForeignKey(Person, models.CASCADE) + group = models.ForeignKey(Group, models.CASCADE) + nodefaultnonull = models.CharField(max_length=5) + + +@python_2_unicode_compatible +class PersonSelfRefM2M(models.Model): + name = 
models.CharField(max_length=5) + friends = models.ManyToManyField('self', through="Friendship", symmetrical=False) + + def __str__(self): + return self.name + + +class Friendship(models.Model): + first = models.ForeignKey(PersonSelfRefM2M, models.CASCADE, related_name="rel_from_set") + second = models.ForeignKey(PersonSelfRefM2M, models.CASCADE, related_name="rel_to_set") + date_friended = models.DateTimeField() + + +# Custom through link fields +@python_2_unicode_compatible +class Event(models.Model): + title = models.CharField(max_length=50) + invitees = models.ManyToManyField( + Person, through='Invitation', + through_fields=('event', 'invitee'), + related_name='events_invited', + ) + + def __str__(self): + return self.title + + +class Invitation(models.Model): + event = models.ForeignKey(Event, models.CASCADE, related_name='invitations') + # field order is deliberately inverted. the target field is "invitee". + inviter = models.ForeignKey(Person, models.CASCADE, related_name='invitations_sent') + invitee = models.ForeignKey(Person, models.CASCADE, related_name='invitations') + + +@python_2_unicode_compatible +class Employee(models.Model): + name = models.CharField(max_length=5) + subordinates = models.ManyToManyField( + 'self', + through="Relationship", + through_fields=('source', 'target'), + symmetrical=False, + ) + + class Meta: + ordering = ('pk',) + + def __str__(self): + return self.name + + +class Relationship(models.Model): + # field order is deliberately inverted. 
+ another = models.ForeignKey(Employee, models.SET_NULL, related_name="rel_another_set", null=True) + target = models.ForeignKey(Employee, models.CASCADE, related_name="rel_target_set") + source = models.ForeignKey(Employee, models.CASCADE, related_name="rel_source_set") + + +class Ingredient(models.Model): + iname = models.CharField(max_length=20, unique=True) + + class Meta: + ordering = ('iname',) + + +class Recipe(models.Model): + rname = models.CharField(max_length=20, unique=True) + ingredients = models.ManyToManyField( + Ingredient, through='RecipeIngredient', related_name='recipes', + ) + + class Meta: + ordering = ('rname',) + + +class RecipeIngredient(models.Model): + ingredient = models.ForeignKey(Ingredient, models.CASCADE, to_field='iname') + recipe = models.ForeignKey(Recipe, models.CASCADE, to_field='rname') diff --git a/tests/m2m_through/tests.py b/tests/m2m_through/tests.py new file mode 100644 index 00000000..47cbbeec --- /dev/null +++ b/tests/m2m_through/tests.py @@ -0,0 +1,472 @@ +from __future__ import unicode_literals + +from datetime import datetime +from operator import attrgetter + +from django.test import TestCase, skipUnlessDBFeature + +from .models import ( + CustomMembership, Employee, Event, Friendship, Group, Ingredient, + Invitation, Membership, Person, PersonSelfRefM2M, Recipe, RecipeIngredient, + Relationship, +) + + +class M2mThroughTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.bob = Person.objects.create(name='Bob') + cls.jim = Person.objects.create(name='Jim') + cls.jane = Person.objects.create(name='Jane') + cls.rock = Group.objects.create(name='Rock') + cls.roll = Group.objects.create(name='Roll') + + def test_retrieve_intermediate_items(self): + Membership.objects.create(person=self.jim, group=self.rock) + Membership.objects.create(person=self.jane, group=self.rock) + + expected = ['Jane', 'Jim'] + self.assertQuerysetEqual( + self.rock.members.all(), + expected, + attrgetter("name") + ) + + def 
test_get_on_intermediate_model(self): + Membership.objects.create(person=self.jane, group=self.rock) + + queryset = Membership.objects.get(person=self.jane, group=self.rock) + + self.assertEqual( + repr(queryset), + '' + ) + + def test_filter_on_intermediate_model(self): + Membership.objects.create(person=self.jim, group=self.rock) + Membership.objects.create(person=self.jane, group=self.rock) + + queryset = Membership.objects.filter(group=self.rock) + + expected = [ + '', + '', + ] + + self.assertQuerysetEqual( + queryset, + expected + ) + + def test_cannot_use_add_on_m2m_with_intermediary_model(self): + msg = 'Cannot use add() on a ManyToManyField which specifies an intermediary model' + + with self.assertRaisesMessage(AttributeError, msg): + self.rock.members.add(self.bob) + + self.assertQuerysetEqual( + self.rock.members.all(), + [] + ) + + def test_cannot_use_create_on_m2m_with_intermediary_model(self): + msg = 'Cannot use create() on a ManyToManyField which specifies an intermediary model' + + with self.assertRaisesMessage(AttributeError, msg): + self.rock.members.create(name='Annie') + + self.assertQuerysetEqual( + self.rock.members.all(), + [] + ) + + def test_cannot_use_remove_on_m2m_with_intermediary_model(self): + Membership.objects.create(person=self.jim, group=self.rock) + msg = 'Cannot use remove() on a ManyToManyField which specifies an intermediary model' + + with self.assertRaisesMessage(AttributeError, msg): + self.rock.members.remove(self.jim) + + self.assertQuerysetEqual( + self.rock.members.all(), + ['Jim', ], + attrgetter("name") + ) + + def test_cannot_use_setattr_on_m2m_with_intermediary_model(self): + msg = 'Cannot set values on a ManyToManyField which specifies an intermediary model' + members = list(Person.objects.filter(name__in=['Bob', 'Jim'])) + + with self.assertRaisesMessage(AttributeError, msg): + self.rock.members.set(members) + + self.assertQuerysetEqual( + self.rock.members.all(), + [] + ) + + def 
test_clear_removes_all_the_m2m_relationships(self): + Membership.objects.create(person=self.jim, group=self.rock) + Membership.objects.create(person=self.jane, group=self.rock) + + self.rock.members.clear() + + self.assertQuerysetEqual( + self.rock.members.all(), + [] + ) + + def test_retrieve_reverse_intermediate_items(self): + Membership.objects.create(person=self.jim, group=self.rock) + Membership.objects.create(person=self.jim, group=self.roll) + + expected = ['Rock', 'Roll'] + self.assertQuerysetEqual( + self.jim.group_set.all(), + expected, + attrgetter("name") + ) + + def test_cannot_use_add_on_reverse_m2m_with_intermediary_model(self): + msg = 'Cannot use add() on a ManyToManyField which specifies an intermediary model' + + with self.assertRaisesMessage(AttributeError, msg): + self.bob.group_set.add(self.bob) + + self.assertQuerysetEqual( + self.bob.group_set.all(), + [] + ) + + def test_cannot_use_create_on_reverse_m2m_with_intermediary_model(self): + msg = 'Cannot use create() on a ManyToManyField which specifies an intermediary model' + + with self.assertRaisesMessage(AttributeError, msg): + self.bob.group_set.create(name='Funk') + + self.assertQuerysetEqual( + self.bob.group_set.all(), + [] + ) + + def test_cannot_use_remove_on_reverse_m2m_with_intermediary_model(self): + Membership.objects.create(person=self.bob, group=self.rock) + msg = 'Cannot use remove() on a ManyToManyField which specifies an intermediary model' + + with self.assertRaisesMessage(AttributeError, msg): + self.bob.group_set.remove(self.rock) + + self.assertQuerysetEqual( + self.bob.group_set.all(), + ['Rock', ], + attrgetter('name') + ) + + def test_cannot_use_setattr_on_reverse_m2m_with_intermediary_model(self): + msg = 'Cannot set values on a ManyToManyField which specifies an intermediary model' + members = list(Group.objects.filter(name__in=['Rock', 'Roll'])) + + with self.assertRaisesMessage(AttributeError, msg): + self.bob.group_set.set(members) + + self.assertQuerysetEqual( + 
self.bob.group_set.all(), + [] + ) + + def test_clear_on_reverse_removes_all_the_m2m_relationships(self): + Membership.objects.create(person=self.jim, group=self.rock) + Membership.objects.create(person=self.jim, group=self.roll) + + self.jim.group_set.clear() + + self.assertQuerysetEqual( + self.jim.group_set.all(), + [] + ) + + def test_query_model_by_attribute_name_of_related_model(self): + Membership.objects.create(person=self.jim, group=self.rock) + Membership.objects.create(person=self.jane, group=self.rock) + Membership.objects.create(person=self.bob, group=self.roll) + Membership.objects.create(person=self.jim, group=self.roll) + Membership.objects.create(person=self.jane, group=self.roll) + + self.assertQuerysetEqual( + Group.objects.filter(members__name='Bob'), + ['Roll', ], + attrgetter("name") + ) + + @skipUnlessDBFeature('supports_microsecond_precision') + def test_order_by_relational_field_through_model(self): + CustomMembership.objects.create(person=self.jim, group=self.rock) + CustomMembership.objects.create(person=self.bob, group=self.rock) + CustomMembership.objects.create(person=self.jane, group=self.roll) + CustomMembership.objects.create(person=self.jim, group=self.roll) + self.assertSequenceEqual( + self.rock.custom_members.order_by('custom_person_related_name'), + [self.jim, self.bob] + ) + self.assertSequenceEqual( + self.roll.custom_members.order_by('custom_person_related_name'), + [self.jane, self.jim] + ) + + def test_query_first_model_by_intermediate_model_attribute(self): + Membership.objects.create( + person=self.jane, group=self.roll, + invite_reason="She was just awesome." + ) + Membership.objects.create( + person=self.jim, group=self.roll, + invite_reason="He is good." + ) + Membership.objects.create(person=self.bob, group=self.roll) + + qs = Group.objects.filter( + membership__invite_reason="She was just awesome." 
+ ) + self.assertQuerysetEqual( + qs, + ['Roll'], + attrgetter("name") + ) + + def test_query_second_model_by_intermediate_model_attribute(self): + Membership.objects.create( + person=self.jane, group=self.roll, + invite_reason="She was just awesome." + ) + Membership.objects.create( + person=self.jim, group=self.roll, + invite_reason="He is good." + ) + Membership.objects.create(person=self.bob, group=self.roll) + + qs = Person.objects.filter( + membership__invite_reason="She was just awesome." + ) + self.assertQuerysetEqual( + qs, + ['Jane'], + attrgetter("name") + ) + + def test_query_model_by_related_model_name(self): + Membership.objects.create(person=self.jim, group=self.rock) + Membership.objects.create(person=self.jane, group=self.rock) + Membership.objects.create(person=self.bob, group=self.roll) + Membership.objects.create(person=self.jim, group=self.roll) + Membership.objects.create(person=self.jane, group=self.roll) + + self.assertQuerysetEqual( + Person.objects.filter(group__name="Rock"), + ['Jane', 'Jim'], + attrgetter("name") + ) + + def test_query_model_by_custom_related_name(self): + CustomMembership.objects.create(person=self.bob, group=self.rock) + CustomMembership.objects.create(person=self.jim, group=self.rock) + + self.assertQuerysetEqual( + Person.objects.filter(custom__name="Rock"), + ['Bob', 'Jim'], + attrgetter("name") + ) + + def test_query_model_by_intermediate_can_return_non_unique_queryset(self): + Membership.objects.create(person=self.jim, group=self.rock) + Membership.objects.create( + person=self.jane, group=self.rock, + date_joined=datetime(2006, 1, 1) + ) + Membership.objects.create( + person=self.bob, group=self.roll, + date_joined=datetime(2004, 1, 1)) + Membership.objects.create(person=self.jim, group=self.roll) + Membership.objects.create( + person=self.jane, group=self.roll, + date_joined=datetime(2004, 1, 1)) + + qs = Person.objects.filter( + membership__date_joined__gt=datetime(2004, 1, 1) + ) + self.assertQuerysetEqual( + 
qs, + ['Jane', 'Jim', 'Jim'], + attrgetter("name") + ) + + def test_custom_related_name_forward_empty_qs(self): + self.assertQuerysetEqual( + self.rock.custom_members.all(), + [] + ) + + def test_custom_related_name_reverse_empty_qs(self): + self.assertQuerysetEqual( + self.bob.custom.all(), + [] + ) + + def test_custom_related_name_forward_non_empty_qs(self): + CustomMembership.objects.create(person=self.bob, group=self.rock) + CustomMembership.objects.create(person=self.jim, group=self.rock) + + self.assertQuerysetEqual( + self.rock.custom_members.all(), + ['Bob', 'Jim'], + attrgetter("name") + ) + + def test_custom_related_name_reverse_non_empty_qs(self): + CustomMembership.objects.create(person=self.bob, group=self.rock) + CustomMembership.objects.create(person=self.jim, group=self.rock) + + self.assertQuerysetEqual( + self.bob.custom.all(), + ['Rock'], + attrgetter("name") + ) + + def test_custom_related_name_doesnt_conflict_with_fky_related_name(self): + CustomMembership.objects.create(person=self.bob, group=self.rock) + + self.assertQuerysetEqual( + self.bob.custom_person_related_name.all(), + [''] + ) + + def test_through_fields(self): + """ + Relations with intermediary tables with multiple FKs + to the M2M's ``to`` model are possible. 
+ """ + event = Event.objects.create(title='Rockwhale 2014') + Invitation.objects.create(event=event, inviter=self.bob, invitee=self.jim) + Invitation.objects.create(event=event, inviter=self.bob, invitee=self.jane) + self.assertQuerysetEqual( + event.invitees.all(), + ['Jane', 'Jim'], + attrgetter('name') + ) + + +class M2mThroughReferentialTests(TestCase): + def test_self_referential_empty_qs(self): + tony = PersonSelfRefM2M.objects.create(name="Tony") + self.assertQuerysetEqual( + tony.friends.all(), + [] + ) + + def test_self_referential_non_symmetrical_first_side(self): + tony = PersonSelfRefM2M.objects.create(name="Tony") + chris = PersonSelfRefM2M.objects.create(name="Chris") + Friendship.objects.create( + first=tony, second=chris, date_friended=datetime.now() + ) + + self.assertQuerysetEqual( + tony.friends.all(), + ['Chris'], + attrgetter("name") + ) + + def test_self_referential_non_symmetrical_second_side(self): + tony = PersonSelfRefM2M.objects.create(name="Tony") + chris = PersonSelfRefM2M.objects.create(name="Chris") + Friendship.objects.create( + first=tony, second=chris, date_friended=datetime.now() + ) + + self.assertQuerysetEqual( + chris.friends.all(), + [] + ) + + def test_self_referential_non_symmetrical_clear_first_side(self): + tony = PersonSelfRefM2M.objects.create(name="Tony") + chris = PersonSelfRefM2M.objects.create(name="Chris") + Friendship.objects.create( + first=tony, second=chris, date_friended=datetime.now() + ) + + chris.friends.clear() + + self.assertQuerysetEqual( + chris.friends.all(), + [] + ) + + # Since this isn't a symmetrical relation, Tony's friend link still exists. 
+ self.assertQuerysetEqual( + tony.friends.all(), + ['Chris'], + attrgetter("name") + ) + + def test_self_referential_symmetrical(self): + tony = PersonSelfRefM2M.objects.create(name="Tony") + chris = PersonSelfRefM2M.objects.create(name="Chris") + Friendship.objects.create( + first=tony, second=chris, date_friended=datetime.now() + ) + Friendship.objects.create( + first=chris, second=tony, date_friended=datetime.now() + ) + + self.assertQuerysetEqual( + tony.friends.all(), + ['Chris'], + attrgetter("name") + ) + + self.assertQuerysetEqual( + chris.friends.all(), + ['Tony'], + attrgetter("name") + ) + + def test_through_fields_self_referential(self): + john = Employee.objects.create(name='john') + peter = Employee.objects.create(name='peter') + mary = Employee.objects.create(name='mary') + harry = Employee.objects.create(name='harry') + + Relationship.objects.create(source=john, target=peter, another=None) + Relationship.objects.create(source=john, target=mary, another=None) + Relationship.objects.create(source=john, target=harry, another=peter) + + self.assertQuerysetEqual( + john.subordinates.all(), + ['peter', 'mary', 'harry'], + attrgetter('name') + ) + + +class M2mThroughToFieldsTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.pea = Ingredient.objects.create(iname='pea') + cls.potato = Ingredient.objects.create(iname='potato') + cls.tomato = Ingredient.objects.create(iname='tomato') + cls.curry = Recipe.objects.create(rname='curry') + RecipeIngredient.objects.create(recipe=cls.curry, ingredient=cls.potato) + RecipeIngredient.objects.create(recipe=cls.curry, ingredient=cls.pea) + RecipeIngredient.objects.create(recipe=cls.curry, ingredient=cls.tomato) + + def test_retrieval(self): + # Forward retrieval + self.assertSequenceEqual(self.curry.ingredients.all(), [self.pea, self.potato, self.tomato]) + # Backward retrieval + self.assertEqual(self.tomato.recipes.get(), self.curry) + + def test_choices(self): + field = 
Recipe._meta.get_field('ingredients') + self.assertEqual( + [choice[0] for choice in field.get_choices(include_blank=False)], + ['pea', 'potato', 'tomato'] + ) diff --git a/tests/m2o_recursive/__init__.py b/tests/m2o_recursive/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/m2o_recursive/models.py b/tests/m2o_recursive/models.py new file mode 100644 index 00000000..d62c514a --- /dev/null +++ b/tests/m2o_recursive/models.py @@ -0,0 +1,33 @@ +""" +Relating an object to itself, many-to-one + +To define a many-to-one relationship between a model and itself, use +``ForeignKey('self', ...)``. + +In this example, a ``Category`` is related to itself. That is, each +``Category`` has a parent ``Category``. + +Set ``related_name`` to designate what the reverse relationship is called. +""" + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Category(models.Model): + name = models.CharField(max_length=20) + parent = models.ForeignKey('self', models.SET_NULL, blank=True, null=True, related_name='child_set') + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Person(models.Model): + full_name = models.CharField(max_length=20) + mother = models.ForeignKey('self', models.SET_NULL, null=True, related_name='mothers_child_set') + father = models.ForeignKey('self', models.SET_NULL, null=True, related_name='fathers_child_set') + + def __str__(self): + return self.full_name diff --git a/tests/m2o_recursive/tests.py b/tests/m2o_recursive/tests.py new file mode 100644 index 00000000..8e730d48 --- /dev/null +++ b/tests/m2o_recursive/tests.py @@ -0,0 +1,43 @@ +from __future__ import unicode_literals + +from django.test import TestCase + +from .models import Category, Person + + +class ManyToOneRecursiveTests(TestCase): + + def setUp(self): + self.r = Category(id=None, name='Root category', parent=None) + self.r.save() + self.c = Category(id=None, 
name='Child category', parent=self.r) + self.c.save() + + def test_m2o_recursive(self): + self.assertQuerysetEqual(self.r.child_set.all(), + ['']) + self.assertEqual(self.r.child_set.get(name__startswith='Child').id, self.c.id) + self.assertIsNone(self.r.parent) + self.assertQuerysetEqual(self.c.child_set.all(), []) + self.assertEqual(self.c.parent.id, self.r.id) + + +class MultipleManyToOneRecursiveTests(TestCase): + + def setUp(self): + self.dad = Person(full_name='John Smith Senior', mother=None, father=None) + self.dad.save() + self.mom = Person(full_name='Jane Smith', mother=None, father=None) + self.mom.save() + self.kid = Person(full_name='John Smith Junior', mother=self.mom, father=self.dad) + self.kid.save() + + def test_m2o_recursive2(self): + self.assertEqual(self.kid.mother.id, self.mom.id) + self.assertEqual(self.kid.father.id, self.dad.id) + self.assertQuerysetEqual(self.dad.fathers_child_set.all(), + ['']) + self.assertQuerysetEqual(self.mom.mothers_child_set.all(), + ['']) + self.assertQuerysetEqual(self.kid.mothers_child_set.all(), []) + self.assertQuerysetEqual(self.kid.fathers_child_set.all(), []) diff --git a/tests/many_to_one_null/__init__.py b/tests/many_to_one_null/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/many_to_one_null/models.py b/tests/many_to_one_null/models.py new file mode 100644 index 00000000..2a67623d --- /dev/null +++ b/tests/many_to_one_null/models.py @@ -0,0 +1,37 @@ +""" +Many-to-one relationships that can be null + +To define a many-to-one relationship that can have a null foreign key, use +``ForeignKey()`` with ``null=True`` . 
+""" + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Reporter(models.Model): + name = models.CharField(max_length=30) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Article(models.Model): + headline = models.CharField(max_length=100) + reporter = models.ForeignKey(Reporter, models.SET_NULL, null=True) + + class Meta: + ordering = ('headline',) + + def __str__(self): + return self.headline + + +class Car(models.Model): + make = models.CharField(max_length=100, null=True, unique=True) + + +class Driver(models.Model): + car = models.ForeignKey(Car, models.SET_NULL, to_field='make', null=True, related_name='drivers') diff --git a/tests/many_to_one_null/tests.py b/tests/many_to_one_null/tests.py new file mode 100644 index 00000000..dc49c61f --- /dev/null +++ b/tests/many_to_one_null/tests.py @@ -0,0 +1,138 @@ +from __future__ import unicode_literals + +from django.test import TestCase + +from .models import Article, Car, Driver, Reporter + + +class ManyToOneNullTests(TestCase): + def setUp(self): + # Create a Reporter. + self.r = Reporter(name='John Smith') + self.r.save() + # Create an Article. + self.a = Article(headline="First", reporter=self.r) + self.a.save() + # Create an Article via the Reporter object. + self.a2 = self.r.article_set.create(headline="Second") + # Create an Article with no Reporter by passing "reporter=None". + self.a3 = Article(headline="Third", reporter=None) + self.a3.save() + # Create another article and reporter + self.r2 = Reporter(name='Paul Jones') + self.r2.save() + self.a4 = self.r2.article_set.create(headline='Fourth') + + def test_get_related(self): + self.assertEqual(self.a.reporter.id, self.r.id) + # Article objects have access to their related Reporter objects. 
+ r = self.a.reporter + self.assertEqual(r.id, self.r.id) + + def test_created_via_related_set(self): + self.assertEqual(self.a2.reporter.id, self.r.id) + + def test_related_set(self): + # Reporter objects have access to their related Article objects. + self.assertQuerysetEqual(self.r.article_set.all(), ['', '']) + self.assertQuerysetEqual(self.r.article_set.filter(headline__startswith='Fir'), ['']) + self.assertEqual(self.r.article_set.count(), 2) + + def test_created_without_related(self): + self.assertIsNone(self.a3.reporter) + # Need to reget a3 to refresh the cache + a3 = Article.objects.get(pk=self.a3.pk) + with self.assertRaises(AttributeError): + getattr(a3.reporter, 'id') + # Accessing an article's 'reporter' attribute returns None + # if the reporter is set to None. + self.assertIsNone(a3.reporter) + # To retrieve the articles with no reporters set, use "reporter__isnull=True". + self.assertQuerysetEqual(Article.objects.filter(reporter__isnull=True), ['']) + # We can achieve the same thing by filtering for the case where the + # reporter is None. + self.assertQuerysetEqual(Article.objects.filter(reporter=None), ['']) + # Set the reporter for the Third article + self.assertQuerysetEqual(self.r.article_set.all(), ['', '']) + self.r.article_set.add(a3) + self.assertQuerysetEqual( + self.r.article_set.all(), + ['', '', ''] + ) + # Remove an article from the set, and check that it was removed. + self.r.article_set.remove(a3) + self.assertQuerysetEqual(self.r.article_set.all(), ['', '']) + self.assertQuerysetEqual(Article.objects.filter(reporter__isnull=True), ['']) + + def test_remove_from_wrong_set(self): + self.assertQuerysetEqual(self.r2.article_set.all(), ['']) + # Try to remove a4 from a set it does not belong to + with self.assertRaises(Reporter.DoesNotExist): + self.r.article_set.remove(self.a4) + self.assertQuerysetEqual(self.r2.article_set.all(), ['']) + + def test_set(self): + # Use manager.set() to allocate ForeignKey. 
Null is legal, so existing + # members of the set that are not in the assignment set are set to null. + self.r2.article_set.set([self.a2, self.a3]) + self.assertQuerysetEqual(self.r2.article_set.all(), ['', '']) + # Use manager.set(clear=True) + self.r2.article_set.set([self.a3, self.a4], clear=True) + self.assertQuerysetEqual(self.r2.article_set.all(), ['', '']) + # Clear the rest of the set + self.r2.article_set.set([]) + self.assertQuerysetEqual(self.r2.article_set.all(), []) + self.assertQuerysetEqual( + Article.objects.filter(reporter__isnull=True), + ['', '', ''] + ) + + def test_assign_clear_related_set(self): + # Use descriptor assignment to allocate ForeignKey. Null is legal, so + # existing members of the set that are not in the assignment set are + # set to null. + self.r2.article_set.set([self.a2, self.a3]) + self.assertQuerysetEqual(self.r2.article_set.all(), ['', '']) + # Clear the rest of the set + self.r.article_set.clear() + self.assertQuerysetEqual(self.r.article_set.all(), []) + self.assertQuerysetEqual( + Article.objects.filter(reporter__isnull=True), + ['', ''] + ) + + def test_assign_with_queryset(self): + # Querysets used in reverse FK assignments are pre-evaluated + # so their value isn't affected by the clearing operation in + # RelatedManager.set() (#19816). 
+ self.r2.article_set.set([self.a2, self.a3]) + + qs = self.r2.article_set.filter(headline="Second") + self.r2.article_set.set(qs) + + self.assertEqual(1, self.r2.article_set.count()) + self.assertEqual(1, qs.count()) + + def test_add_efficiency(self): + r = Reporter.objects.create() + articles = [] + for _ in range(3): + articles.append(Article.objects.create()) + with self.assertNumQueries(1): + r.article_set.add(*articles) + self.assertEqual(r.article_set.count(), 3) + + def test_clear_efficiency(self): + r = Reporter.objects.create() + for _ in range(3): + r.article_set.create() + with self.assertNumQueries(1): + r.article_set.clear() + self.assertEqual(r.article_set.count(), 0) + + def test_related_null_to_field(self): + c1 = Car.objects.create() + d1 = Driver.objects.create() + self.assertIs(d1.car, None) + with self.assertNumQueries(0): + self.assertEqual(list(c1.drivers.all()), []) diff --git a/tests/migration_test_data_persistence/__init__.py b/tests/migration_test_data_persistence/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/migration_test_data_persistence/migrations/0001_initial.py b/tests/migration_test_data_persistence/migrations/0001_initial.py new file mode 100644 index 00000000..6c19c4c8 --- /dev/null +++ b/tests/migration_test_data_persistence/migrations/0001_initial.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ] + + operations = [ + migrations.CreateModel( + name='Book', + fields=[ + ('id', models.AutoField(verbose_name='ID', primary_key=True, serialize=False, auto_created=True)), + ('title', models.CharField(max_length=100)), + ], + options={ + }, + bases=(models.Model,), + ), + ] diff --git a/tests/migration_test_data_persistence/migrations/0002_add_book.py b/tests/migration_test_data_persistence/migrations/0002_add_book.py new file mode 100644 index 
00000000..6ce7fff2 --- /dev/null +++ b/tests/migration_test_data_persistence/migrations/0002_add_book.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from django.db import migrations + + +def add_book(apps, schema_editor): + apps.get_model("migration_test_data_persistence", "Book").objects.using( + schema_editor.connection.alias, + ).create( + title="I Love Django", + ) + + +class Migration(migrations.Migration): + + dependencies = [("migration_test_data_persistence", "0001_initial")] + + operations = [ + migrations.RunPython( + add_book, + ), + ] diff --git a/tests/migration_test_data_persistence/migrations/__init__.py b/tests/migration_test_data_persistence/migrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/migration_test_data_persistence/models.py b/tests/migration_test_data_persistence/models.py new file mode 100644 index 00000000..c1572d5d --- /dev/null +++ b/tests/migration_test_data_persistence/models.py @@ -0,0 +1,12 @@ +from django.db import models + + +class Book(models.Model): + title = models.CharField(max_length=100) + + +class Unmanaged(models.Model): + title = models.CharField(max_length=100) + + class Meta: + managed = False diff --git a/tests/migration_test_data_persistence/tests.py b/tests/migration_test_data_persistence/tests.py new file mode 100644 index 00000000..862a06c4 --- /dev/null +++ b/tests/migration_test_data_persistence/tests.py @@ -0,0 +1,31 @@ +from django.test import TestCase, TransactionTestCase + +from .models import Book + + +class MigrationDataPersistenceTestCase(TransactionTestCase): + """ + Data loaded in migrations is available if + TransactionTestCase.serialized_rollback = True. 
+ """ + + available_apps = ["migration_test_data_persistence"] + serialized_rollback = True + + def test_persistence(self): + self.assertEqual( + Book.objects.count(), + 1, + ) + + +class MigrationDataNormalPersistenceTestCase(TestCase): + """ + Data loaded in migrations is available on TestCase + """ + + def test_persistence(self): + self.assertEqual( + Book.objects.count(), + 1, + ) diff --git a/tests/migrations2/__init__.py b/tests/migrations2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/migrations2/models.py b/tests/migrations2/models.py new file mode 100644 index 00000000..3ea7a1df --- /dev/null +++ b/tests/migrations2/models.py @@ -0,0 +1 @@ +# Required for migration detection (#22645) diff --git a/tests/migrations2/test_migrations_2/0001_initial.py b/tests/migrations2/test_migrations_2/0001_initial.py new file mode 100644 index 00000000..02cbd97f --- /dev/null +++ b/tests/migrations2/test_migrations_2/0001_initial.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [("migrations", "0002_second")] + + operations = [ + + migrations.CreateModel( + "OtherAuthor", + [ + ("id", models.AutoField(primary_key=True)), + ("name", models.CharField(max_length=255)), + ("slug", models.SlugField(null=True)), + ("age", models.IntegerField(default=0)), + ("silly_field", models.BooleanField(default=False)), + ], + ), + + ] diff --git a/tests/migrations2/test_migrations_2/__init__.py b/tests/migrations2/test_migrations_2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/migrations2/test_migrations_2_first/0001_initial.py b/tests/migrations2/test_migrations_2_first/0001_initial.py new file mode 100644 index 00000000..e31d1d50 --- /dev/null +++ b/tests/migrations2/test_migrations_2_first/0001_initial.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +from __future__ import 
unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("migrations", "__first__"), + ] + + operations = [ + + migrations.CreateModel( + "OtherAuthor", + [ + ("id", models.AutoField(primary_key=True)), + ("name", models.CharField(max_length=255)), + ("slug", models.SlugField(null=True)), + ("age", models.IntegerField(default=0)), + ("silly_field", models.BooleanField(default=False)), + ], + ), + + ] diff --git a/tests/migrations2/test_migrations_2_first/0002_second.py b/tests/migrations2/test_migrations_2_first/0002_second.py new file mode 100644 index 00000000..a3ca7dac --- /dev/null +++ b/tests/migrations2/test_migrations_2_first/0002_second.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [("migrations2", "0001_initial")] + + operations = [ + + migrations.CreateModel( + "Bookstore", + [ + ("id", models.AutoField(primary_key=True)), + ("name", models.CharField(max_length=255)), + ("slug", models.SlugField(null=True)), + ], + ), + + ] diff --git a/tests/migrations2/test_migrations_2_first/__init__.py b/tests/migrations2/test_migrations_2_first/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/migrations2/test_migrations_2_no_deps/0001_initial.py b/tests/migrations2/test_migrations_2_no_deps/0001_initial.py new file mode 100644 index 00000000..22137065 --- /dev/null +++ b/tests/migrations2/test_migrations_2_no_deps/0001_initial.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [] + + operations = [ + + migrations.CreateModel( + "OtherAuthor", + [ + ("id", models.AutoField(primary_key=True)), + ("name", models.CharField(max_length=255)), + ("slug", models.SlugField(null=True)), + 
("age", models.IntegerField(default=0)), + ("silly_field", models.BooleanField(default=False)), + ], + ), + + ] diff --git a/tests/migrations2/test_migrations_2_no_deps/__init__.py b/tests/migrations2/test_migrations_2_no_deps/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/multiple_database/__init__.py b/tests/multiple_database/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/multiple_database/fixtures/multidb-common.json b/tests/multiple_database/fixtures/multidb-common.json new file mode 100644 index 00000000..33134173 --- /dev/null +++ b/tests/multiple_database/fixtures/multidb-common.json @@ -0,0 +1,10 @@ +[ + { + "pk": 1, + "model": "multiple_database.book", + "fields": { + "title": "The Definitive Guide to Django", + "published": "2009-7-8" + } + } +] \ No newline at end of file diff --git a/tests/multiple_database/fixtures/multidb.default.json b/tests/multiple_database/fixtures/multidb.default.json new file mode 100644 index 00000000..379b18a8 --- /dev/null +++ b/tests/multiple_database/fixtures/multidb.default.json @@ -0,0 +1,26 @@ +[ + { + "pk": 1, + "model": "multiple_database.person", + "fields": { + "name": "Marty Alchin" + } + }, + { + "pk": 2, + "model": "multiple_database.person", + "fields": { + "name": "George Vilches" + } + }, + { + "pk": 2, + "model": "multiple_database.book", + "fields": { + "title": "Pro Django", + "published": "2008-12-16", + "authors": [["Marty Alchin"]], + "editor": ["George Vilches"] + } + } +] diff --git a/tests/multiple_database/fixtures/multidb.other.json b/tests/multiple_database/fixtures/multidb.other.json new file mode 100644 index 00000000..c64f4902 --- /dev/null +++ b/tests/multiple_database/fixtures/multidb.other.json @@ -0,0 +1,26 @@ +[ + { + "pk": 1, + "model": "multiple_database.person", + "fields": { + "name": "Mark Pilgrim" + } + }, + { + "pk": 2, + "model": "multiple_database.person", + "fields": { + "name": "Chris Mills" + } + }, + { + "pk": 2, + 
"model": "multiple_database.book", + "fields": { + "title": "Dive into Python", + "published": "2009-5-4", + "authors": [["Mark Pilgrim"]], + "editor": ["Chris Mills"] + } + } +] \ No newline at end of file diff --git a/tests/multiple_database/fixtures/pets.json b/tests/multiple_database/fixtures/pets.json new file mode 100644 index 00000000..89756a3e --- /dev/null +++ b/tests/multiple_database/fixtures/pets.json @@ -0,0 +1,18 @@ +[ + { + "pk": 1, + "model": "multiple_database.pet", + "fields": { + "name": "Mr Bigglesworth", + "owner": 1 + } + }, + { + "pk": 2, + "model": "multiple_database.pet", + "fields": { + "name": "Spot", + "owner": 2 + } + } +] \ No newline at end of file diff --git a/tests/multiple_database/models.py b/tests/multiple_database/models.py new file mode 100644 index 00000000..367cd31d --- /dev/null +++ b/tests/multiple_database/models.py @@ -0,0 +1,89 @@ +from django.contrib.auth.models import User +from django.contrib.contenttypes.fields import ( + GenericForeignKey, GenericRelation, +) +from django.contrib.contenttypes.models import ContentType +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Review(models.Model): + source = models.CharField(max_length=100) + content_type = models.ForeignKey(ContentType, models.CASCADE) + object_id = models.PositiveIntegerField() + content_object = GenericForeignKey() + + def __str__(self): + return self.source + + class Meta: + ordering = ('source',) + + +class PersonManager(models.Manager): + def get_by_natural_key(self, name): + return self.get(name=name) + + +@python_2_unicode_compatible +class Person(models.Model): + objects = PersonManager() + name = models.CharField(max_length=100) + + def __str__(self): + return self.name + + class Meta: + ordering = ('name',) + + +# This book manager doesn't do anything interesting; it just +# exists to strip out the 'extra_arg' argument to certain +# calls. 
This argument is used to establish that the BookManager +# is actually getting used when it should be. +class BookManager(models.Manager): + def create(self, *args, **kwargs): + kwargs.pop('extra_arg', None) + return super(BookManager, self).create(*args, **kwargs) + + def get_or_create(self, *args, **kwargs): + kwargs.pop('extra_arg', None) + return super(BookManager, self).get_or_create(*args, **kwargs) + + +@python_2_unicode_compatible +class Book(models.Model): + objects = BookManager() + title = models.CharField(max_length=100) + published = models.DateField() + authors = models.ManyToManyField(Person) + editor = models.ForeignKey(Person, models.SET_NULL, null=True, related_name='edited') + reviews = GenericRelation(Review) + pages = models.IntegerField(default=100) + + def __str__(self): + return self.title + + class Meta: + ordering = ('title',) + + +@python_2_unicode_compatible +class Pet(models.Model): + name = models.CharField(max_length=100) + owner = models.ForeignKey(Person, models.CASCADE) + + def __str__(self): + return self.name + + class Meta: + ordering = ('name',) + + +class UserProfile(models.Model): + user = models.OneToOneField(User, models.SET_NULL, null=True) + flavor = models.CharField(max_length=100) + + class Meta: + ordering = ('flavor',) diff --git a/tests/multiple_database/routers.py b/tests/multiple_database/routers.py new file mode 100644 index 00000000..e467cf56 --- /dev/null +++ b/tests/multiple_database/routers.py @@ -0,0 +1,62 @@ +from __future__ import unicode_literals + +from django.db import DEFAULT_DB_ALIAS + + +class TestRouter(object): + """ + Vaguely behave like primary/replica, but the databases aren't assumed to + propagate changes. 
+ """ + + def db_for_read(self, model, instance=None, **hints): + if instance: + return instance._state.db or 'other' + return 'other' + + def db_for_write(self, model, **hints): + return DEFAULT_DB_ALIAS + + def allow_relation(self, obj1, obj2, **hints): + return obj1._state.db in ('default', 'other') and obj2._state.db in ('default', 'other') + + def allow_migrate(self, db, app_label, **hints): + return True + + +class AuthRouter(object): + """ + Control all database operations on models in the contrib.auth application. + """ + + def db_for_read(self, model, **hints): + "Point all read operations on auth models to 'default'" + if model._meta.app_label == 'auth': + # We use default here to ensure we can tell the difference + # between a read request and a write request for Auth objects + return 'default' + return None + + def db_for_write(self, model, **hints): + "Point all operations on auth models to 'other'" + if model._meta.app_label == 'auth': + return 'other' + return None + + def allow_relation(self, obj1, obj2, **hints): + "Allow any relation if a model in Auth is involved" + if obj1._meta.app_label == 'auth' or obj2._meta.app_label == 'auth': + return True + return None + + def allow_migrate(self, db, app_label, **hints): + "Make sure the auth app only appears on the 'other' db" + if app_label == 'auth': + return db == 'other' + return None + + +class WriteRouter(object): + # A router that only expresses an opinion on writes + def db_for_write(self, model, **hints): + return 'writer' diff --git a/tests/multiple_database/tests.py b/tests/multiple_database/tests.py new file mode 100644 index 00000000..cc762f04 --- /dev/null +++ b/tests/multiple_database/tests.py @@ -0,0 +1,2042 @@ +from __future__ import unicode_literals + +import datetime +import pickle +from operator import attrgetter + +import django +from django.contrib.auth.models import User +from django.contrib.contenttypes.models import ContentType +from django.core import management +from django.db 
import DEFAULT_DB_ALIAS, connections, router, transaction +from django.db.models import signals +from django.db.utils import ConnectionRouter +from django.test import SimpleTestCase, TestCase, override_settings +from django.utils.six import StringIO + +from .models import Book, Person, Pet, Review, UserProfile +from .routers import AuthRouter, TestRouter, WriteRouter + + +class QueryTestCase(TestCase): + multi_db = True + + def test_db_selection(self): + "Querysets will use the default database by default" + self.assertEqual(Book.objects.db, DEFAULT_DB_ALIAS) + self.assertEqual(Book.objects.all().db, DEFAULT_DB_ALIAS) + + self.assertEqual(Book.objects.using('other').db, 'other') + + self.assertEqual(Book.objects.db_manager('other').db, 'other') + self.assertEqual(Book.objects.db_manager('other').all().db, 'other') + + def test_default_creation(self): + "Objects created on the default database don't leak onto other databases" + # Create a book on the default database using create() + Book.objects.create(title="Pro Django", published=datetime.date(2008, 12, 16)) + + # Create a book on the default database using a save + dive = Book() + dive.title = "Dive into Python" + dive.published = datetime.date(2009, 5, 4) + dive.save() + + # Book exists on the default database, but not on other database + try: + Book.objects.get(title="Pro Django") + Book.objects.using('default').get(title="Pro Django") + except Book.DoesNotExist: + self.fail('"Pro Django" should exist on default database') + + with self.assertRaises(Book.DoesNotExist): + Book.objects.using('other').get(title="Pro Django") + + try: + Book.objects.get(title="Dive into Python") + Book.objects.using('default').get(title="Dive into Python") + except Book.DoesNotExist: + self.fail('"Dive into Python" should exist on default database') + + with self.assertRaises(Book.DoesNotExist): + Book.objects.using('other').get(title="Dive into Python") + + def test_other_creation(self): + "Objects created on another database 
don't leak onto the default database" + # Create a book on the second database + Book.objects.using('other').create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + + # Create a book on the default database using a save + dive = Book() + dive.title = "Dive into Python" + dive.published = datetime.date(2009, 5, 4) + dive.save(using='other') + + # Book exists on the default database, but not on other database + try: + Book.objects.using('other').get(title="Pro Django") + except Book.DoesNotExist: + self.fail('"Pro Django" should exist on other database') + + with self.assertRaises(Book.DoesNotExist): + Book.objects.get(title="Pro Django") + with self.assertRaises(Book.DoesNotExist): + Book.objects.using('default').get(title="Pro Django") + + try: + Book.objects.using('other').get(title="Dive into Python") + except Book.DoesNotExist: + self.fail('"Dive into Python" should exist on other database') + + with self.assertRaises(Book.DoesNotExist): + Book.objects.get(title="Dive into Python") + with self.assertRaises(Book.DoesNotExist): + Book.objects.using('default').get(title="Dive into Python") + + def test_refresh(self): + dive = Book(title="Dive into Python", published=datetime.date(2009, 5, 4)) + dive.save(using='other') + dive2 = Book.objects.using('other').get() + dive2.title = "Dive into Python (on default)" + dive2.save(using='default') + dive.refresh_from_db() + self.assertEqual(dive.title, "Dive into Python") + dive.refresh_from_db(using='default') + self.assertEqual(dive.title, "Dive into Python (on default)") + self.assertEqual(dive._state.db, "default") + + def test_basic_queries(self): + "Queries are constrained to a single database" + dive = Book.objects.using('other').create(title="Dive into Python", published=datetime.date(2009, 5, 4)) + + dive = Book.objects.using('other').get(published=datetime.date(2009, 5, 4)) + self.assertEqual(dive.title, "Dive into Python") + with self.assertRaises(Book.DoesNotExist): + 
Book.objects.using('default').get(published=datetime.date(2009, 5, 4)) + + dive = Book.objects.using('other').get(title__icontains="dive") + self.assertEqual(dive.title, "Dive into Python") + with self.assertRaises(Book.DoesNotExist): + Book.objects.using('default').get(title__icontains="dive") + + dive = Book.objects.using('other').get(title__iexact="dive INTO python") + self.assertEqual(dive.title, "Dive into Python") + with self.assertRaises(Book.DoesNotExist): + Book.objects.using('default').get(title__iexact="dive INTO python") + + dive = Book.objects.using('other').get(published__year=2009) + self.assertEqual(dive.title, "Dive into Python") + self.assertEqual(dive.published, datetime.date(2009, 5, 4)) + with self.assertRaises(Book.DoesNotExist): + Book.objects.using('default').get(published__year=2009) + + years = Book.objects.using('other').dates('published', 'year') + self.assertEqual([o.year for o in years], [2009]) + years = Book.objects.using('default').dates('published', 'year') + self.assertEqual([o.year for o in years], []) + + months = Book.objects.using('other').dates('published', 'month') + self.assertEqual([o.month for o in months], [5]) + months = Book.objects.using('default').dates('published', 'month') + self.assertEqual([o.month for o in months], []) + + def test_m2m_separation(self): + "M2M fields are constrained to a single database" + # Create a book and author on the default database + pro = Book.objects.create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + + marty = Person.objects.create(name="Marty Alchin") + + # Create a book and author on the other database + dive = Book.objects.using('other').create(title="Dive into Python", published=datetime.date(2009, 5, 4)) + + mark = Person.objects.using('other').create(name="Mark Pilgrim") + + # Save the author relations + pro.authors.set([marty]) + dive.authors.set([mark]) + + # Inspect the m2m tables directly. 
+ # There should be 1 entry in each database + self.assertEqual(Book.authors.through.objects.using('default').count(), 1) + self.assertEqual(Book.authors.through.objects.using('other').count(), 1) + + # Queries work across m2m joins + self.assertEqual( + list(Book.objects.using('default').filter(authors__name='Marty Alchin').values_list('title', flat=True)), + ['Pro Django'] + ) + self.assertEqual( + list(Book.objects.using('other').filter(authors__name='Marty Alchin').values_list('title', flat=True)), + [] + ) + + self.assertEqual( + list(Book.objects.using('default').filter(authors__name='Mark Pilgrim').values_list('title', flat=True)), + [] + ) + self.assertEqual( + list(Book.objects.using('other').filter(authors__name='Mark Pilgrim').values_list('title', flat=True)), + ['Dive into Python'] + ) + + # Reget the objects to clear caches + dive = Book.objects.using('other').get(title="Dive into Python") + mark = Person.objects.using('other').get(name="Mark Pilgrim") + + # Retrieve related object by descriptor. 
Related objects should be database-bound + self.assertEqual(list(dive.authors.all().values_list('name', flat=True)), ['Mark Pilgrim']) + + self.assertEqual(list(mark.book_set.all().values_list('title', flat=True)), ['Dive into Python']) + + def test_m2m_forward_operations(self): + "M2M forward manipulations are all constrained to a single DB" + # Create a book and author on the other database + dive = Book.objects.using('other').create(title="Dive into Python", published=datetime.date(2009, 5, 4)) + mark = Person.objects.using('other').create(name="Mark Pilgrim") + + # Save the author relations + dive.authors.set([mark]) + + # Add a second author + john = Person.objects.using('other').create(name="John Smith") + self.assertEqual( + list(Book.objects.using('other').filter(authors__name='John Smith').values_list('title', flat=True)), + [] + ) + + dive.authors.add(john) + self.assertEqual( + list(Book.objects.using('other').filter(authors__name='Mark Pilgrim').values_list('title', flat=True)), + ['Dive into Python'] + ) + self.assertEqual( + list(Book.objects.using('other').filter(authors__name='John Smith').values_list('title', flat=True)), + ['Dive into Python'] + ) + + # Remove the second author + dive.authors.remove(john) + self.assertEqual( + list(Book.objects.using('other').filter(authors__name='Mark Pilgrim').values_list('title', flat=True)), + ['Dive into Python'] + ) + self.assertEqual( + list(Book.objects.using('other').filter(authors__name='John Smith').values_list('title', flat=True)), + [] + ) + + # Clear all authors + dive.authors.clear() + self.assertEqual( + list(Book.objects.using('other').filter(authors__name='Mark Pilgrim').values_list('title', flat=True)), + [] + ) + self.assertEqual( + list(Book.objects.using('other').filter(authors__name='John Smith').values_list('title', flat=True)), + [] + ) + + # Create an author through the m2m interface + dive.authors.create(name='Jane Brown') + self.assertEqual( + 
list(Book.objects.using('other').filter(authors__name='Mark Pilgrim').values_list('title', flat=True)), + [] + ) + self.assertEqual( + list(Book.objects.using('other').filter(authors__name='Jane Brown').values_list('title', flat=True)), + ['Dive into Python'] + ) + + def test_m2m_reverse_operations(self): + "M2M reverse manipulations are all constrained to a single DB" + # Create a book and author on the other database + dive = Book.objects.using('other').create(title="Dive into Python", published=datetime.date(2009, 5, 4)) + mark = Person.objects.using('other').create(name="Mark Pilgrim") + + # Save the author relations + dive.authors.set([mark]) + + # Create a second book on the other database + grease = Book.objects.using('other').create(title="Greasemonkey Hacks", published=datetime.date(2005, 11, 1)) + + # Add a books to the m2m + mark.book_set.add(grease) + self.assertEqual( + list(Person.objects.using('other').filter(book__title='Dive into Python').values_list('name', flat=True)), + ['Mark Pilgrim'] + ) + self.assertEqual( + list( + Person.objects.using('other').filter(book__title='Greasemonkey Hacks').values_list('name', flat=True) + ), + ['Mark Pilgrim'] + ) + + # Remove a book from the m2m + mark.book_set.remove(grease) + self.assertEqual( + list(Person.objects.using('other').filter(book__title='Dive into Python').values_list('name', flat=True)), + ['Mark Pilgrim'] + ) + self.assertEqual( + list( + Person.objects.using('other').filter(book__title='Greasemonkey Hacks').values_list('name', flat=True) + ), + [] + ) + + # Clear the books associated with mark + mark.book_set.clear() + self.assertEqual( + list(Person.objects.using('other').filter(book__title='Dive into Python').values_list('name', flat=True)), + [] + ) + self.assertEqual( + list( + Person.objects.using('other').filter(book__title='Greasemonkey Hacks').values_list('name', flat=True) + ), + [] + ) + + # Create a book through the m2m interface + mark.book_set.create(title="Dive into HTML5", 
published=datetime.date(2020, 1, 1)) + self.assertEqual( + list(Person.objects.using('other').filter(book__title='Dive into Python').values_list('name', flat=True)), + [] + ) + self.assertEqual( + list(Person.objects.using('other').filter(book__title='Dive into HTML5').values_list('name', flat=True)), + ['Mark Pilgrim'] + ) + + def test_m2m_cross_database_protection(self): + "Operations that involve sharing M2M objects across databases raise an error" + # Create a book and author on the default database + pro = Book.objects.create(title="Pro Django", published=datetime.date(2008, 12, 16)) + + marty = Person.objects.create(name="Marty Alchin") + + # Create a book and author on the other database + dive = Book.objects.using('other').create(title="Dive into Python", published=datetime.date(2009, 5, 4)) + + mark = Person.objects.using('other').create(name="Mark Pilgrim") + # Set a foreign key set with an object from a different database + with self.assertRaises(ValueError): + with transaction.atomic(using='default'): + marty.edited.set([pro, dive]) + + # Add to an m2m with an object from a different database + with self.assertRaises(ValueError): + with transaction.atomic(using='default'): + marty.book_set.add(dive) + + # Set a m2m with an object from a different database + with self.assertRaises(ValueError): + with transaction.atomic(using='default'): + marty.book_set.set([pro, dive]) + + # Add to a reverse m2m with an object from a different database + with self.assertRaises(ValueError): + with transaction.atomic(using='other'): + dive.authors.add(marty) + + # Set a reverse m2m with an object from a different database + with self.assertRaises(ValueError): + with transaction.atomic(using='other'): + dive.authors.set([mark, marty]) + + def test_m2m_deletion(self): + "Cascaded deletions of m2m relations issue queries on the right database" + # Create a book and author on the other database + dive = Book.objects.using('other').create(title="Dive into Python", 
published=datetime.date(2009, 5, 4)) + mark = Person.objects.using('other').create(name="Mark Pilgrim") + dive.authors.set([mark]) + + # Check the initial state + self.assertEqual(Person.objects.using('default').count(), 0) + self.assertEqual(Book.objects.using('default').count(), 0) + self.assertEqual(Book.authors.through.objects.using('default').count(), 0) + + self.assertEqual(Person.objects.using('other').count(), 1) + self.assertEqual(Book.objects.using('other').count(), 1) + self.assertEqual(Book.authors.through.objects.using('other').count(), 1) + + # Delete the object on the other database + dive.delete(using='other') + + self.assertEqual(Person.objects.using('default').count(), 0) + self.assertEqual(Book.objects.using('default').count(), 0) + self.assertEqual(Book.authors.through.objects.using('default').count(), 0) + + # The person still exists ... + self.assertEqual(Person.objects.using('other').count(), 1) + # ... but the book has been deleted + self.assertEqual(Book.objects.using('other').count(), 0) + # ... and the relationship object has also been deleted. + self.assertEqual(Book.authors.through.objects.using('other').count(), 0) + + # Now try deletion in the reverse direction. 
Set up the relation again + dive = Book.objects.using('other').create(title="Dive into Python", published=datetime.date(2009, 5, 4)) + dive.authors.set([mark]) + + # Check the initial state + self.assertEqual(Person.objects.using('default').count(), 0) + self.assertEqual(Book.objects.using('default').count(), 0) + self.assertEqual(Book.authors.through.objects.using('default').count(), 0) + + self.assertEqual(Person.objects.using('other').count(), 1) + self.assertEqual(Book.objects.using('other').count(), 1) + self.assertEqual(Book.authors.through.objects.using('other').count(), 1) + + # Delete the object on the other database + mark.delete(using='other') + + self.assertEqual(Person.objects.using('default').count(), 0) + self.assertEqual(Book.objects.using('default').count(), 0) + self.assertEqual(Book.authors.through.objects.using('default').count(), 0) + + # The person has been deleted ... + self.assertEqual(Person.objects.using('other').count(), 0) + # ... but the book still exists + self.assertEqual(Book.objects.using('other').count(), 1) + # ... and the relationship object has been deleted. 
+ self.assertEqual(Book.authors.through.objects.using('other').count(), 0) + + def test_foreign_key_separation(self): + "FK fields are constrained to a single database" + # Create a book and author on the default database + pro = Book.objects.create(title="Pro Django", published=datetime.date(2008, 12, 16)) + + george = Person.objects.create(name="George Vilches") + + # Create a book and author on the other database + dive = Book.objects.using('other').create(title="Dive into Python", published=datetime.date(2009, 5, 4)) + chris = Person.objects.using('other').create(name="Chris Mills") + + # Save the author's favorite books + pro.editor = george + pro.save() + + dive.editor = chris + dive.save() + + pro = Book.objects.using('default').get(title="Pro Django") + self.assertEqual(pro.editor.name, "George Vilches") + + dive = Book.objects.using('other').get(title="Dive into Python") + self.assertEqual(dive.editor.name, "Chris Mills") + + # Queries work across foreign key joins + self.assertEqual( + list(Person.objects.using('default').filter(edited__title='Pro Django').values_list('name', flat=True)), + ['George Vilches'] + ) + self.assertEqual( + list(Person.objects.using('other').filter(edited__title='Pro Django').values_list('name', flat=True)), + [] + ) + + self.assertEqual( + list( + Person.objects.using('default').filter(edited__title='Dive into Python').values_list('name', flat=True) + ), + [] + ) + self.assertEqual( + list( + Person.objects.using('other').filter(edited__title='Dive into Python').values_list('name', flat=True) + ), + ['Chris Mills'] + ) + + # Reget the objects to clear caches + chris = Person.objects.using('other').get(name="Chris Mills") + dive = Book.objects.using('other').get(title="Dive into Python") + + # Retrieve related object by descriptor. 
Related objects should be database-bound + self.assertEqual(list(chris.edited.values_list('title', flat=True)), ['Dive into Python']) + + def test_foreign_key_reverse_operations(self): + "FK reverse manipulations are all constrained to a single DB" + dive = Book.objects.using('other').create(title="Dive into Python", published=datetime.date(2009, 5, 4)) + chris = Person.objects.using('other').create(name="Chris Mills") + + # Save the author relations + dive.editor = chris + dive.save() + + # Add a second book edited by chris + html5 = Book.objects.using('other').create(title="Dive into HTML5", published=datetime.date(2010, 3, 15)) + self.assertEqual( + list(Person.objects.using('other').filter(edited__title='Dive into HTML5').values_list('name', flat=True)), + [] + ) + + chris.edited.add(html5) + self.assertEqual( + list(Person.objects.using('other').filter(edited__title='Dive into HTML5').values_list('name', flat=True)), + ['Chris Mills'] + ) + self.assertEqual( + list( + Person.objects.using('other').filter(edited__title='Dive into Python').values_list('name', flat=True) + ), + ['Chris Mills'] + ) + + # Remove the second editor + chris.edited.remove(html5) + self.assertEqual( + list(Person.objects.using('other').filter(edited__title='Dive into HTML5').values_list('name', flat=True)), + [] + ) + self.assertEqual( + list( + Person.objects.using('other').filter(edited__title='Dive into Python').values_list('name', flat=True) + ), + ['Chris Mills'] + ) + + # Clear all edited books + chris.edited.clear() + self.assertEqual( + list(Person.objects.using('other').filter(edited__title='Dive into HTML5').values_list('name', flat=True)), + [] + ) + self.assertEqual( + list( + Person.objects.using('other').filter(edited__title='Dive into Python').values_list('name', flat=True) + ), + [] + ) + + # Create an author through the m2m interface + chris.edited.create(title='Dive into Water', published=datetime.date(2010, 3, 15)) + self.assertEqual( + 
list(Person.objects.using('other').filter(edited__title='Dive into HTML5').values_list('name', flat=True)), + [] + ) + self.assertEqual( + list(Person.objects.using('other').filter(edited__title='Dive into Water').values_list('name', flat=True)), + ['Chris Mills'] + ) + self.assertEqual( + list( + Person.objects.using('other').filter(edited__title='Dive into Python').values_list('name', flat=True) + ), + [] + ) + + def test_foreign_key_cross_database_protection(self): + "Operations that involve sharing FK objects across databases raise an error" + # Create a book and author on the default database + pro = Book.objects.create(title="Pro Django", published=datetime.date(2008, 12, 16)) + marty = Person.objects.create(name="Marty Alchin") + + # Create a book and author on the other database + dive = Book.objects.using('other').create(title="Dive into Python", published=datetime.date(2009, 5, 4)) + + # Set a foreign key with an object from a different database + with self.assertRaises(ValueError): + dive.editor = marty + + # Set a foreign key set with an object from a different database + with self.assertRaises(ValueError): + with transaction.atomic(using='default'): + marty.edited.set([pro, dive]) + + # Add to a foreign key set with an object from a different database + with self.assertRaises(ValueError): + with transaction.atomic(using='default'): + marty.edited.add(dive) + + def test_foreign_key_deletion(self): + "Cascaded deletions of Foreign Key relations issue queries on the right database" + mark = Person.objects.using('other').create(name="Mark Pilgrim") + Pet.objects.using('other').create(name="Fido", owner=mark) + + # Check the initial state + self.assertEqual(Person.objects.using('default').count(), 0) + self.assertEqual(Pet.objects.using('default').count(), 0) + + self.assertEqual(Person.objects.using('other').count(), 1) + self.assertEqual(Pet.objects.using('other').count(), 1) + + # Delete the person object, which will cascade onto the pet + 
mark.delete(using='other') + + self.assertEqual(Person.objects.using('default').count(), 0) + self.assertEqual(Pet.objects.using('default').count(), 0) + + # Both the pet and the person have been deleted from the right database + self.assertEqual(Person.objects.using('other').count(), 0) + self.assertEqual(Pet.objects.using('other').count(), 0) + + def test_foreign_key_validation(self): + "ForeignKey.validate() uses the correct database" + mickey = Person.objects.using('other').create(name="Mickey") + pluto = Pet.objects.using('other').create(name="Pluto", owner=mickey) + self.assertIsNone(pluto.full_clean()) + + # Any router that accesses `model` in db_for_read() works here. + @override_settings(DATABASE_ROUTERS=[AuthRouter()]) + def test_foreign_key_validation_with_router(self): + """ + ForeignKey.validate() passes `model` to db_for_read() even if + model_instance=None. + """ + if django.VERSION < (1, 11, 0): + self.skipTest("TODO fix AttributeError: type object 'NoneType' has no attribute '_meta'") + mickey = Person.objects.create(name="Mickey") + owner_field = Pet._meta.get_field('owner') + self.assertEqual(owner_field.clean(mickey.pk, None), mickey.pk) + + def test_o2o_separation(self): + "OneToOne fields are constrained to a single database" + # Create a user and profile on the default database + alice = User.objects.db_manager('default').create_user('alice', 'alice@example.com') + alice_profile = UserProfile.objects.using('default').create(user=alice, flavor='chocolate') + + # Create a user and profile on the other database + bob = User.objects.db_manager('other').create_user('bob', 'bob@example.com') + bob_profile = UserProfile.objects.using('other').create(user=bob, flavor='crunchy frog') + + # Retrieve related objects; queries should be database constrained + alice = User.objects.using('default').get(username="alice") + self.assertEqual(alice.userprofile.flavor, "chocolate") + + bob = User.objects.using('other').get(username="bob") + 
self.assertEqual(bob.userprofile.flavor, "crunchy frog") + + # Queries work across joins + self.assertEqual( + list( + User.objects.using('default') + .filter(userprofile__flavor='chocolate').values_list('username', flat=True) + ), + ['alice'] + ) + self.assertEqual( + list( + User.objects.using('other') + .filter(userprofile__flavor='chocolate').values_list('username', flat=True) + ), + [] + ) + + self.assertEqual( + list( + User.objects.using('default') + .filter(userprofile__flavor='crunchy frog').values_list('username', flat=True) + ), + [] + ) + self.assertEqual( + list( + User.objects.using('other') + .filter(userprofile__flavor='crunchy frog').values_list('username', flat=True) + ), + ['bob'] + ) + + # Reget the objects to clear caches + alice_profile = UserProfile.objects.using('default').get(flavor='chocolate') + bob_profile = UserProfile.objects.using('other').get(flavor='crunchy frog') + + # Retrieve related object by descriptor. Related objects should be database-bound + self.assertEqual(alice_profile.user.username, 'alice') + self.assertEqual(bob_profile.user.username, 'bob') + + def test_o2o_cross_database_protection(self): + "Operations that involve sharing FK objects across databases raise an error" + # Create a user and profile on the default database + alice = User.objects.db_manager('default').create_user('alice', 'alice@example.com') + + # Create a user and profile on the other database + bob = User.objects.db_manager('other').create_user('bob', 'bob@example.com') + + # Set a one-to-one relation with an object from a different database + alice_profile = UserProfile.objects.using('default').create(user=alice, flavor='chocolate') + with self.assertRaises(ValueError): + bob.userprofile = alice_profile + + # BUT! if you assign a FK object when the base object hasn't + # been saved yet, you implicitly assign the database for the + # base object. 
+ bob_profile = UserProfile.objects.using('other').create(user=bob, flavor='crunchy frog') + + new_bob_profile = UserProfile(flavor="spring surprise") + + # assigning a profile requires an explicit pk as the object isn't saved + charlie = User(pk=51, username='charlie', email='charlie@example.com') + charlie.set_unusable_password() + + # initially, no db assigned + self.assertIsNone(new_bob_profile._state.db) + self.assertIsNone(charlie._state.db) + + # old object comes from 'other', so the new object is set to use 'other'... + new_bob_profile.user = bob + charlie.userprofile = bob_profile + self.assertEqual(new_bob_profile._state.db, 'other') + self.assertEqual(charlie._state.db, 'other') + + # ... but it isn't saved yet + self.assertEqual(list(User.objects.using('other').values_list('username', flat=True)), ['bob']) + self.assertEqual(list(UserProfile.objects.using('other').values_list('flavor', flat=True)), ['crunchy frog']) + + # When saved (no using required), new objects goes to 'other' + charlie.save() + bob_profile.save() + new_bob_profile.save() + self.assertEqual(list(User.objects.using('default').values_list('username', flat=True)), ['alice']) + self.assertEqual(list(User.objects.using('other').values_list('username', flat=True)), ['bob', 'charlie']) + self.assertEqual(list(UserProfile.objects.using('default').values_list('flavor', flat=True)), ['chocolate']) + self.assertEqual( + list(UserProfile.objects.using('other').values_list('flavor', flat=True)), + ['crunchy frog', 'spring surprise'] + ) + + # This also works if you assign the O2O relation in the constructor + denise = User.objects.db_manager('other').create_user('denise', 'denise@example.com') + denise_profile = UserProfile(flavor="tofu", user=denise) + + self.assertEqual(denise_profile._state.db, 'other') + # ... 
but it isn't saved yet + self.assertEqual(list(UserProfile.objects.using('default').values_list('flavor', flat=True)), ['chocolate']) + self.assertEqual( + list(UserProfile.objects.using('other').values_list('flavor', flat=True)), + ['crunchy frog', 'spring surprise'] + ) + + # When saved, the new profile goes to 'other' + denise_profile.save() + self.assertEqual(list(UserProfile.objects.using('default').values_list('flavor', flat=True)), ['chocolate']) + self.assertEqual( + list(UserProfile.objects.using('other').values_list('flavor', flat=True)), + ['crunchy frog', 'spring surprise', 'tofu'] + ) + + def test_generic_key_separation(self): + "Generic fields are constrained to a single database" + # Create a book and author on the default database + pro = Book.objects.create(title="Pro Django", published=datetime.date(2008, 12, 16)) + review1 = Review.objects.create(source="Python Monthly", content_object=pro) + + # Create a book and author on the other database + dive = Book.objects.using('other').create(title="Dive into Python", published=datetime.date(2009, 5, 4)) + + review2 = Review.objects.using('other').create(source="Python Weekly", content_object=dive) + + review1 = Review.objects.using('default').get(source="Python Monthly") + self.assertEqual(review1.content_object.title, "Pro Django") + + review2 = Review.objects.using('other').get(source="Python Weekly") + self.assertEqual(review2.content_object.title, "Dive into Python") + + # Reget the objects to clear caches + dive = Book.objects.using('other').get(title="Dive into Python") + + # Retrieve related object by descriptor. 
Related objects should be database-bound + self.assertEqual(list(dive.reviews.all().values_list('source', flat=True)), ['Python Weekly']) + + def test_generic_key_reverse_operations(self): + "Generic reverse manipulations are all constrained to a single DB" + dive = Book.objects.using('other').create(title="Dive into Python", published=datetime.date(2009, 5, 4)) + temp = Book.objects.using('other').create(title="Temp", published=datetime.date(2009, 5, 4)) + review1 = Review.objects.using('other').create(source="Python Weekly", content_object=dive) + review2 = Review.objects.using('other').create(source="Python Monthly", content_object=temp) + + self.assertEqual( + list(Review.objects.using('default').filter(object_id=dive.pk).values_list('source', flat=True)), + [] + ) + self.assertEqual( + list(Review.objects.using('other').filter(object_id=dive.pk).values_list('source', flat=True)), + ['Python Weekly'] + ) + + # Add a second review + dive.reviews.add(review2) + self.assertEqual( + list(Review.objects.using('default').filter(object_id=dive.pk).values_list('source', flat=True)), + [] + ) + self.assertEqual( + list(Review.objects.using('other').filter(object_id=dive.pk).values_list('source', flat=True)), + ['Python Monthly', 'Python Weekly'] + ) + + # Remove the second author + dive.reviews.remove(review1) + self.assertEqual( + list(Review.objects.using('default').filter(object_id=dive.pk).values_list('source', flat=True)), + [] + ) + self.assertEqual( + list(Review.objects.using('other').filter(object_id=dive.pk).values_list('source', flat=True)), + ['Python Monthly'] + ) + + # Clear all reviews + dive.reviews.clear() + self.assertEqual( + list(Review.objects.using('default').filter(object_id=dive.pk).values_list('source', flat=True)), + [] + ) + self.assertEqual( + list(Review.objects.using('other').filter(object_id=dive.pk).values_list('source', flat=True)), + [] + ) + + # Create an author through the generic interface + dive.reviews.create(source='Python Daily') 
+ self.assertEqual( + list(Review.objects.using('default').filter(object_id=dive.pk).values_list('source', flat=True)), + [] + ) + self.assertEqual( + list(Review.objects.using('other').filter(object_id=dive.pk).values_list('source', flat=True)), + ['Python Daily'] + ) + + def test_generic_key_cross_database_protection(self): + "Operations that involve sharing generic key objects across databases raise an error" + # Create a book and author on the default database + pro = Book.objects.create(title="Pro Django", published=datetime.date(2008, 12, 16)) + review1 = Review.objects.create(source="Python Monthly", content_object=pro) + + # Create a book and author on the other database + dive = Book.objects.using('other').create(title="Dive into Python", published=datetime.date(2009, 5, 4)) + + Review.objects.using('other').create(source="Python Weekly", content_object=dive) + + # Set a foreign key with an object from a different database + with self.assertRaises(ValueError): + review1.content_object = dive + + # Add to a foreign key set with an object from a different database + with self.assertRaises(ValueError): + with transaction.atomic(using='other'): + dive.reviews.add(review1) + + # BUT! if you assign a FK object when the base object hasn't + # been saved yet, you implicitly assign the database for the + # base object. + review3 = Review(source="Python Daily") + # initially, no db assigned + self.assertIsNone(review3._state.db) + + # Dive comes from 'other', so review3 is set to use 'other'... + review3.content_object = dive + self.assertEqual(review3._state.db, 'other') + # ... 
but it isn't saved yet + self.assertEqual( + list(Review.objects.using('default').filter(object_id=pro.pk).values_list('source', flat=True)), + ['Python Monthly'] + ) + self.assertEqual( + list(Review.objects.using('other').filter(object_id=dive.pk).values_list('source', flat=True)), + ['Python Weekly'] + ) + + # When saved, John goes to 'other' + review3.save() + self.assertEqual( + list(Review.objects.using('default').filter(object_id=pro.pk).values_list('source', flat=True)), + ['Python Monthly'] + ) + self.assertEqual( + list(Review.objects.using('other').filter(object_id=dive.pk).values_list('source', flat=True)), + ['Python Daily', 'Python Weekly'] + ) + + def test_generic_key_deletion(self): + "Cascaded deletions of Generic Key relations issue queries on the right database" + dive = Book.objects.using('other').create(title="Dive into Python", published=datetime.date(2009, 5, 4)) + Review.objects.using('other').create(source="Python Weekly", content_object=dive) + + # Check the initial state + self.assertEqual(Book.objects.using('default').count(), 0) + self.assertEqual(Review.objects.using('default').count(), 0) + + self.assertEqual(Book.objects.using('other').count(), 1) + self.assertEqual(Review.objects.using('other').count(), 1) + + # Delete the Book object, which will cascade onto the pet + dive.delete(using='other') + + self.assertEqual(Book.objects.using('default').count(), 0) + self.assertEqual(Review.objects.using('default').count(), 0) + + # Both the pet and the person have been deleted from the right database + self.assertEqual(Book.objects.using('other').count(), 0) + self.assertEqual(Review.objects.using('other').count(), 0) + + def test_ordering(self): + "get_next_by_XXX commands stick to a single database" + Book.objects.create(title="Pro Django", published=datetime.date(2008, 12, 16)) + dive = Book.objects.using('other').create(title="Dive into Python", published=datetime.date(2009, 5, 4)) + learn = 
Book.objects.using('other').create(title="Learning Python", published=datetime.date(2008, 7, 16)) + + self.assertEqual(learn.get_next_by_published().title, "Dive into Python") + self.assertEqual(dive.get_previous_by_published().title, "Learning Python") + + def test_raw(self): + "test the raw() method across databases" + dive = Book.objects.using('other').create(title="Dive into Python", published=datetime.date(2009, 5, 4)) + val = Book.objects.db_manager("other").raw('SELECT id FROM multiple_database_book') + self.assertQuerysetEqual(val, [dive.pk], attrgetter("pk")) + + val = Book.objects.raw('SELECT id FROM multiple_database_book').using('other') + self.assertQuerysetEqual(val, [dive.pk], attrgetter("pk")) + + def test_select_related(self): + "Database assignment is retained if an object is retrieved with select_related()" + # Create a book and author on the other database + mark = Person.objects.using('other').create(name="Mark Pilgrim") + Book.objects.using('other').create( + title="Dive into Python", + published=datetime.date(2009, 5, 4), + editor=mark, + ) + + # Retrieve the Person using select_related() + book = Book.objects.using('other').select_related('editor').get(title="Dive into Python") + + # The editor instance should have a db state + self.assertEqual(book.editor._state.db, 'other') + + def test_subquery(self): + """Make sure as_sql works with subqueries and primary/replica.""" + sub = Person.objects.using('other').filter(name='fff') + qs = Book.objects.filter(editor__in=sub) + + # When you call __str__ on the query object, it doesn't know about using + # so it falls back to the default. If the subquery explicitly uses a + # different database, an error should be raised. 
+ with self.assertRaises(ValueError): + str(qs.query) + + # Evaluating the query shouldn't work, either + with self.assertRaises(ValueError): + for obj in qs: + pass + + def test_related_manager(self): + "Related managers return managers, not querysets" + mark = Person.objects.using('other').create(name="Mark Pilgrim") + + # extra_arg is removed by the BookManager's implementation of + # create(); but the BookManager's implementation won't get called + # unless edited returns a Manager, not a queryset + mark.book_set.create(title="Dive into Python", published=datetime.date(2009, 5, 4), extra_arg=True) + mark.book_set.get_or_create(title="Dive into Python", published=datetime.date(2009, 5, 4), extra_arg=True) + mark.edited.create(title="Dive into Water", published=datetime.date(2009, 5, 4), extra_arg=True) + mark.edited.get_or_create(title="Dive into Water", published=datetime.date(2009, 5, 4), extra_arg=True) + + +class ConnectionRouterTestCase(SimpleTestCase): + @override_settings(DATABASE_ROUTERS=[ + 'multiple_database.tests.TestRouter', + 'multiple_database.tests.WriteRouter']) + def test_router_init_default(self): + connection_router = ConnectionRouter() + self.assertListEqual([r.__class__.__name__ for r in connection_router.routers], ['TestRouter', 'WriteRouter']) + + def test_router_init_arg(self): + connection_router = ConnectionRouter([ + 'multiple_database.tests.TestRouter', + 'multiple_database.tests.WriteRouter' + ]) + self.assertListEqual([r.__class__.__name__ for r in connection_router.routers], ['TestRouter', 'WriteRouter']) + + # Init with instances instead of strings + connection_router = ConnectionRouter([TestRouter(), WriteRouter()]) + self.assertListEqual([r.__class__.__name__ for r in connection_router.routers], ['TestRouter', 'WriteRouter']) + + +# Make the 'other' database appear to be a replica of the 'default' +@override_settings(DATABASE_ROUTERS=[TestRouter()]) +class RouterTestCase(TestCase): + multi_db = True + + def 
test_db_selection(self): + "Querysets obey the router for db suggestions" + self.assertEqual(Book.objects.db, 'other') + self.assertEqual(Book.objects.all().db, 'other') + + self.assertEqual(Book.objects.using('default').db, 'default') + + self.assertEqual(Book.objects.db_manager('default').db, 'default') + self.assertEqual(Book.objects.db_manager('default').all().db, 'default') + + def test_migrate_selection(self): + "Synchronization behavior is predictable" + + self.assertTrue(router.allow_migrate_model('default', User)) + self.assertTrue(router.allow_migrate_model('default', Book)) + + self.assertTrue(router.allow_migrate_model('other', User)) + self.assertTrue(router.allow_migrate_model('other', Book)) + + with override_settings(DATABASE_ROUTERS=[TestRouter(), AuthRouter()]): + # Add the auth router to the chain. TestRouter is a universal + # synchronizer, so it should have no effect. + self.assertTrue(router.allow_migrate_model('default', User)) + self.assertTrue(router.allow_migrate_model('default', Book)) + + self.assertTrue(router.allow_migrate_model('other', User)) + self.assertTrue(router.allow_migrate_model('other', Book)) + + with override_settings(DATABASE_ROUTERS=[AuthRouter(), TestRouter()]): + # Now check what happens if the router order is reversed. + self.assertFalse(router.allow_migrate_model('default', User)) + self.assertTrue(router.allow_migrate_model('default', Book)) + + self.assertTrue(router.allow_migrate_model('other', User)) + self.assertTrue(router.allow_migrate_model('other', Book)) + + def test_partial_router(self): + "A router can choose to implement a subset of methods" + dive = Book.objects.using('other').create(title="Dive into Python", + published=datetime.date(2009, 5, 4)) + + # First check the baseline behavior. 
+ + self.assertEqual(router.db_for_read(User), 'other') + self.assertEqual(router.db_for_read(Book), 'other') + + self.assertEqual(router.db_for_write(User), 'default') + self.assertEqual(router.db_for_write(Book), 'default') + + self.assertTrue(router.allow_relation(dive, dive)) + + self.assertTrue(router.allow_migrate_model('default', User)) + self.assertTrue(router.allow_migrate_model('default', Book)) + + with override_settings(DATABASE_ROUTERS=[WriteRouter(), AuthRouter(), TestRouter()]): + self.assertEqual(router.db_for_read(User), 'default') + self.assertEqual(router.db_for_read(Book), 'other') + + self.assertEqual(router.db_for_write(User), 'writer') + self.assertEqual(router.db_for_write(Book), 'writer') + + self.assertTrue(router.allow_relation(dive, dive)) + + self.assertFalse(router.allow_migrate_model('default', User)) + self.assertTrue(router.allow_migrate_model('default', Book)) + + def test_database_routing(self): + marty = Person.objects.using('default').create(name="Marty Alchin") + pro = Book.objects.using('default').create(title="Pro Django", + published=datetime.date(2008, 12, 16), + editor=marty) + pro.authors.set([marty]) + + # Create a book and author on the other database + Book.objects.using('other').create(title="Dive into Python", published=datetime.date(2009, 5, 4)) + + # An update query will be routed to the default database + Book.objects.filter(title='Pro Django').update(pages=200) + + with self.assertRaises(Book.DoesNotExist): + # By default, the get query will be directed to 'other' + Book.objects.get(title='Pro Django') + + # But the same query issued explicitly at a database will work. + pro = Book.objects.using('default').get(title='Pro Django') + + # The update worked. + self.assertEqual(pro.pages, 200) + + # An update query with an explicit using clause will be routed + # to the requested database. 
+ Book.objects.using('other').filter(title='Dive into Python').update(pages=300) + self.assertEqual(Book.objects.get(title='Dive into Python').pages, 300) + + # Related object queries stick to the same database + # as the original object, regardless of the router + self.assertEqual(list(pro.authors.values_list('name', flat=True)), ['Marty Alchin']) + self.assertEqual(pro.editor.name, 'Marty Alchin') + + # get_or_create is a special case. The get needs to be targeted at + # the write database in order to avoid potential transaction + # consistency problems + book, created = Book.objects.get_or_create(title="Pro Django") + self.assertFalse(created) + + book, created = Book.objects.get_or_create(title="Dive Into Python", + defaults={'published': datetime.date(2009, 5, 4)}) + self.assertTrue(created) + + # Check the head count of objects + self.assertEqual(Book.objects.using('default').count(), 2) + self.assertEqual(Book.objects.using('other').count(), 1) + # If a database isn't specified, the read database is used + self.assertEqual(Book.objects.count(), 1) + + # A delete query will also be routed to the default database + Book.objects.filter(pages__gt=150).delete() + + # The default database has lost the book. + self.assertEqual(Book.objects.using('default').count(), 1) + self.assertEqual(Book.objects.using('other').count(), 1) + + def test_invalid_set_foreign_key_assignment(self): + marty = Person.objects.using('default').create(name="Marty Alchin") + dive = Book.objects.using('other').create( + title="Dive into Python", + published=datetime.date(2009, 5, 4), + ) + # Set a foreign key set with an object from a different database + msg = " instance isn't saved. Use bulk=False or save the object first." 
+ with self.assertRaisesMessage(ValueError, msg): + marty.edited.set([dive]) + + def test_foreign_key_cross_database_protection(self): + "Foreign keys can cross databases if they two databases have a common source" + # Create a book and author on the default database + pro = Book.objects.using('default').create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + + marty = Person.objects.using('default').create(name="Marty Alchin") + + # Create a book and author on the other database + dive = Book.objects.using('other').create(title="Dive into Python", + published=datetime.date(2009, 5, 4)) + + mark = Person.objects.using('other').create(name="Mark Pilgrim") + + # Set a foreign key with an object from a different database + dive.editor = marty + + # Database assignments of original objects haven't changed... + self.assertEqual(marty._state.db, 'default') + self.assertEqual(pro._state.db, 'default') + self.assertEqual(dive._state.db, 'other') + self.assertEqual(mark._state.db, 'other') + + # ... but they will when the affected object is saved. + dive.save() + self.assertEqual(dive._state.db, 'default') + + # ...and the source database now has a copy of any object saved + Book.objects.using('default').get(title='Dive into Python').delete() + + # This isn't a real primary/replica database, so restore the original from other + dive = Book.objects.using('other').get(title='Dive into Python') + self.assertEqual(dive._state.db, 'other') + + # Set a foreign key set with an object from a different database + marty.edited.set([pro, dive], bulk=False) + + # Assignment implies a save, so database assignments of original objects have changed... 
+ self.assertEqual(marty._state.db, 'default') + self.assertEqual(pro._state.db, 'default') + self.assertEqual(dive._state.db, 'default') + self.assertEqual(mark._state.db, 'other') + + # ...and the source database now has a copy of any object saved + Book.objects.using('default').get(title='Dive into Python').delete() + + # This isn't a real primary/replica database, so restore the original from other + dive = Book.objects.using('other').get(title='Dive into Python') + self.assertEqual(dive._state.db, 'other') + + # Add to a foreign key set with an object from a different database + marty.edited.add(dive, bulk=False) + + # Add implies a save, so database assignments of original objects have changed... + self.assertEqual(marty._state.db, 'default') + self.assertEqual(pro._state.db, 'default') + self.assertEqual(dive._state.db, 'default') + self.assertEqual(mark._state.db, 'other') + + # ...and the source database now has a copy of any object saved + Book.objects.using('default').get(title='Dive into Python').delete() + + # This isn't a real primary/replica database, so restore the original from other + dive = Book.objects.using('other').get(title='Dive into Python') + + # If you assign a FK object when the base object hasn't + # been saved yet, you implicitly assign the database for the + # base object. + chris = Person(name="Chris Mills") + html5 = Book(title="Dive into HTML5", published=datetime.date(2010, 3, 15)) + # initially, no db assigned + self.assertIsNone(chris._state.db) + self.assertIsNone(html5._state.db) + + # old object comes from 'other', so the new object is set to use the + # source of 'other'... 
+ self.assertEqual(dive._state.db, 'other') + chris.save() + dive.editor = chris + html5.editor = mark + + self.assertEqual(dive._state.db, 'other') + self.assertEqual(mark._state.db, 'other') + self.assertEqual(chris._state.db, 'default') + self.assertEqual(html5._state.db, 'default') + + # This also works if you assign the FK in the constructor + water = Book(title="Dive into Water", published=datetime.date(2001, 1, 1), editor=mark) + self.assertEqual(water._state.db, 'default') + + # For the remainder of this test, create a copy of 'mark' in the + # 'default' database to prevent integrity errors on backends that + # don't defer constraints checks until the end of the transaction + mark.save(using='default') + + # This moved 'mark' in the 'default' database, move it back in 'other' + mark.save(using='other') + self.assertEqual(mark._state.db, 'other') + + # If you create an object through a FK relation, it will be + # written to the write database, even if the original object + # was on the read database + cheesecake = mark.edited.create(title='Dive into Cheesecake', published=datetime.date(2010, 3, 15)) + self.assertEqual(cheesecake._state.db, 'default') + + # Same goes for get_or_create, regardless of whether getting or creating + cheesecake, created = mark.edited.get_or_create( + title='Dive into Cheesecake', + published=datetime.date(2010, 3, 15), + ) + self.assertEqual(cheesecake._state.db, 'default') + + puddles, created = mark.edited.get_or_create(title='Dive into Puddles', published=datetime.date(2010, 3, 15)) + self.assertEqual(puddles._state.db, 'default') + + def test_m2m_cross_database_protection(self): + "M2M relations can cross databases if the database share a source" + # Create books and authors on the inverse to the usual database + pro = Book.objects.using('other').create(pk=1, title="Pro Django", + published=datetime.date(2008, 12, 16)) + + marty = Person.objects.using('other').create(pk=1, name="Marty Alchin") + + dive = 
Book.objects.using('default').create(pk=2, title="Dive into Python", + published=datetime.date(2009, 5, 4)) + + mark = Person.objects.using('default').create(pk=2, name="Mark Pilgrim") + + # Now save back onto the usual database. + # This simulates primary/replica - the objects exist on both database, + # but the _state.db is as it is for all other tests. + pro.save(using='default') + marty.save(using='default') + dive.save(using='other') + mark.save(using='other') + + # We have 2 of both types of object on both databases + self.assertEqual(Book.objects.using('default').count(), 2) + self.assertEqual(Book.objects.using('other').count(), 2) + self.assertEqual(Person.objects.using('default').count(), 2) + self.assertEqual(Person.objects.using('other').count(), 2) + + # Set a m2m set with an object from a different database + marty.book_set.set([pro, dive]) + + # Database assignments don't change + self.assertEqual(marty._state.db, 'default') + self.assertEqual(pro._state.db, 'default') + self.assertEqual(dive._state.db, 'other') + self.assertEqual(mark._state.db, 'other') + + # All m2m relations should be saved on the default database + self.assertEqual(Book.authors.through.objects.using('default').count(), 2) + self.assertEqual(Book.authors.through.objects.using('other').count(), 0) + + # Reset relations + Book.authors.through.objects.using('default').delete() + + # Add to an m2m with an object from a different database + marty.book_set.add(dive) + + # Database assignments don't change + self.assertEqual(marty._state.db, 'default') + self.assertEqual(pro._state.db, 'default') + self.assertEqual(dive._state.db, 'other') + self.assertEqual(mark._state.db, 'other') + + # All m2m relations should be saved on the default database + self.assertEqual(Book.authors.through.objects.using('default').count(), 1) + self.assertEqual(Book.authors.through.objects.using('other').count(), 0) + + # Reset relations + Book.authors.through.objects.using('default').delete() + + # Set a 
reverse m2m with an object from a different database + dive.authors.set([mark, marty]) + + # Database assignments don't change + self.assertEqual(marty._state.db, 'default') + self.assertEqual(pro._state.db, 'default') + self.assertEqual(dive._state.db, 'other') + self.assertEqual(mark._state.db, 'other') + + # All m2m relations should be saved on the default database + self.assertEqual(Book.authors.through.objects.using('default').count(), 2) + self.assertEqual(Book.authors.through.objects.using('other').count(), 0) + + # Reset relations + Book.authors.through.objects.using('default').delete() + + self.assertEqual(Book.authors.through.objects.using('default').count(), 0) + self.assertEqual(Book.authors.through.objects.using('other').count(), 0) + + # Add to a reverse m2m with an object from a different database + dive.authors.add(marty) + + # Database assignments don't change + self.assertEqual(marty._state.db, 'default') + self.assertEqual(pro._state.db, 'default') + self.assertEqual(dive._state.db, 'other') + self.assertEqual(mark._state.db, 'other') + + # All m2m relations should be saved on the default database + self.assertEqual(Book.authors.through.objects.using('default').count(), 1) + self.assertEqual(Book.authors.through.objects.using('other').count(), 0) + + # If you create an object through a M2M relation, it will be + # written to the write database, even if the original object + # was on the read database + alice = dive.authors.create(name='Alice') + self.assertEqual(alice._state.db, 'default') + + # Same goes for get_or_create, regardless of whether getting or creating + alice, created = dive.authors.get_or_create(name='Alice') + self.assertEqual(alice._state.db, 'default') + + bob, created = dive.authors.get_or_create(name='Bob') + self.assertEqual(bob._state.db, 'default') + + def test_o2o_cross_database_protection(self): + "Operations that involve sharing FK objects across databases raise an error" + # Create a user and profile on the default 
database + alice = User.objects.db_manager('default').create_user('alice', 'alice@example.com') + + # Create a user and profile on the other database + bob = User.objects.db_manager('other').create_user('bob', 'bob@example.com') + + # Set a one-to-one relation with an object from a different database + alice_profile = UserProfile.objects.create(user=alice, flavor='chocolate') + bob.userprofile = alice_profile + + # Database assignments of original objects haven't changed... + self.assertEqual(alice._state.db, 'default') + self.assertEqual(alice_profile._state.db, 'default') + self.assertEqual(bob._state.db, 'other') + + # ... but they will when the affected object is saved. + bob.save() + self.assertEqual(bob._state.db, 'default') + + def test_generic_key_cross_database_protection(self): + "Generic Key operations can span databases if they share a source" + # Create a book and author on the default database + pro = Book.objects.using( + 'default').create(title="Pro Django", published=datetime.date(2008, 12, 16)) + + review1 = Review.objects.using( + 'default').create(source="Python Monthly", content_object=pro) + + # Create a book and author on the other database + dive = Book.objects.using( + 'other').create(title="Dive into Python", published=datetime.date(2009, 5, 4)) + + review2 = Review.objects.using( + 'other').create(source="Python Weekly", content_object=dive) + + # Set a generic foreign key with an object from a different database + review1.content_object = dive + + # Database assignments of original objects haven't changed... + self.assertEqual(pro._state.db, 'default') + self.assertEqual(review1._state.db, 'default') + self.assertEqual(dive._state.db, 'other') + self.assertEqual(review2._state.db, 'other') + + # ... but they will when the affected object is saved. 
+ dive.save() + self.assertEqual(review1._state.db, 'default') + self.assertEqual(dive._state.db, 'default') + + # ...and the source database now has a copy of any object saved + Book.objects.using('default').get(title='Dive into Python').delete() + + # This isn't a real primary/replica database, so restore the original from other + dive = Book.objects.using('other').get(title='Dive into Python') + self.assertEqual(dive._state.db, 'other') + + # Add to a generic foreign key set with an object from a different database + dive.reviews.add(review1) + + # Database assignments of original objects haven't changed... + self.assertEqual(pro._state.db, 'default') + self.assertEqual(review1._state.db, 'default') + self.assertEqual(dive._state.db, 'other') + self.assertEqual(review2._state.db, 'other') + + # ... but they will when the affected object is saved. + dive.save() + self.assertEqual(dive._state.db, 'default') + + # ...and the source database now has a copy of any object saved + Book.objects.using('default').get(title='Dive into Python').delete() + + # BUT! if you assign a FK object when the base object hasn't + # been saved yet, you implicitly assign the database for the + # base object. + review3 = Review(source="Python Daily") + # initially, no db assigned + self.assertIsNone(review3._state.db) + + # Dive comes from 'other', so review3 is set to use the source of 'other'... 
+ review3.content_object = dive + self.assertEqual(review3._state.db, 'default') + + # If you create an object through a M2M relation, it will be + # written to the write database, even if the original object + # was on the read database + dive = Book.objects.using('other').get(title='Dive into Python') + nyt = dive.reviews.create(source="New York Times", content_object=dive) + self.assertEqual(nyt._state.db, 'default') + + def test_m2m_managers(self): + "M2M relations are represented by managers, and can be controlled like managers" + pro = Book.objects.using('other').create(pk=1, title="Pro Django", + published=datetime.date(2008, 12, 16)) + + marty = Person.objects.using('other').create(pk=1, name="Marty Alchin") + + self.assertEqual(pro.authors.db, 'other') + self.assertEqual(pro.authors.db_manager('default').db, 'default') + self.assertEqual(pro.authors.db_manager('default').all().db, 'default') + + self.assertEqual(marty.book_set.db, 'other') + self.assertEqual(marty.book_set.db_manager('default').db, 'default') + self.assertEqual(marty.book_set.db_manager('default').all().db, 'default') + + def test_foreign_key_managers(self): + "FK reverse relations are represented by managers, and can be controlled like managers" + marty = Person.objects.using('other').create(pk=1, name="Marty Alchin") + Book.objects.using('other').create(pk=1, title="Pro Django", + published=datetime.date(2008, 12, 16), + editor=marty) + + self.assertEqual(marty.edited.db, 'other') + self.assertEqual(marty.edited.db_manager('default').db, 'default') + self.assertEqual(marty.edited.db_manager('default').all().db, 'default') + + def test_generic_key_managers(self): + "Generic key relations are represented by managers, and can be controlled like managers" + pro = Book.objects.using('other').create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + + Review.objects.using('other').create(source="Python Monthly", + content_object=pro) + + self.assertEqual(pro.reviews.db, 'other') + 
self.assertEqual(pro.reviews.db_manager('default').db, 'default') + self.assertEqual(pro.reviews.db_manager('default').all().db, 'default') + + def test_subquery(self): + """Make sure as_sql works with subqueries and primary/replica.""" + # Create a book and author on the other database + + mark = Person.objects.using('other').create(name="Mark Pilgrim") + Book.objects.using('other').create(title="Dive into Python", + published=datetime.date(2009, 5, 4), + editor=mark) + + sub = Person.objects.filter(name='Mark Pilgrim') + qs = Book.objects.filter(editor__in=sub) + + # When you call __str__ on the query object, it doesn't know about using + # so it falls back to the default. Don't let routing instructions + # force the subquery to an incompatible database. + str(qs.query) + + # If you evaluate the query, it should work, running on 'other' + self.assertEqual(list(qs.values_list('title', flat=True)), ['Dive into Python']) + + def test_deferred_models(self): + mark_def = Person.objects.using('default').create(name="Mark Pilgrim") + mark_other = Person.objects.using('other').create(name="Mark Pilgrim") + orig_b = Book.objects.using('other').create(title="Dive into Python", + published=datetime.date(2009, 5, 4), + editor=mark_other) + b = Book.objects.using('other').only('title').get(pk=orig_b.pk) + self.assertEqual(b.published, datetime.date(2009, 5, 4)) + b = Book.objects.using('other').only('title').get(pk=orig_b.pk) + b.editor = mark_def + b.save(using='default') + self.assertEqual(Book.objects.using('default').get(pk=b.pk).published, + datetime.date(2009, 5, 4)) + + +@override_settings(DATABASE_ROUTERS=[AuthRouter()]) +class AuthTestCase(TestCase): + multi_db = True + + def test_auth_manager(self): + "The methods on the auth manager obey database hints" + # Create one user using default allocation policy + User.objects.create_user('alice', 'alice@example.com') + + # Create another user, explicitly specifying the database + 
User.objects.db_manager('default').create_user('bob', 'bob@example.com') + + # The second user only exists on the other database + alice = User.objects.using('other').get(username='alice') + + self.assertEqual(alice.username, 'alice') + self.assertEqual(alice._state.db, 'other') + + with self.assertRaises(User.DoesNotExist): + User.objects.using('default').get(username='alice') + + # The second user only exists on the default database + bob = User.objects.using('default').get(username='bob') + + self.assertEqual(bob.username, 'bob') + self.assertEqual(bob._state.db, 'default') + + with self.assertRaises(User.DoesNotExist): + User.objects.using('other').get(username='bob') + + # That is... there is one user on each database + self.assertEqual(User.objects.using('default').count(), 1) + self.assertEqual(User.objects.using('other').count(), 1) + + def test_dumpdata(self): + "dumpdata honors allow_migrate restrictions on the router" + User.objects.create_user('alice', 'alice@example.com') + User.objects.db_manager('default').create_user('bob', 'bob@example.com') + + # dumping the default database doesn't try to include auth because + # allow_migrate prohibits auth on default + new_io = StringIO() + management.call_command('dumpdata', 'auth', format='json', database='default', stdout=new_io) + command_output = new_io.getvalue().strip() + self.assertEqual(command_output, '[]') + + # dumping the other database does include auth + new_io = StringIO() + management.call_command('dumpdata', 'auth', format='json', database='other', stdout=new_io) + command_output = new_io.getvalue().strip() + self.assertIn('"email": "alice@example.com"', command_output) + + +class AntiPetRouter(object): + # A router that only expresses an opinion on migrate, + # passing pets to the 'other' database + + def allow_migrate(self, db, app_label, model_name=None, **hints): + if db == 'other': + return model_name == 'pet' + else: + return model_name != 'pet' + + +class FixtureTestCase(TestCase): + 
multi_db = True + fixtures = ['multidb-common', 'multidb'] + + @override_settings(DATABASE_ROUTERS=[AntiPetRouter()]) + def test_fixture_loading(self): + "Multi-db fixtures are loaded correctly" + # "Pro Django" exists on the default database, but not on other database + Book.objects.get(title="Pro Django") + Book.objects.using('default').get(title="Pro Django") + + with self.assertRaises(Book.DoesNotExist): + Book.objects.using('other').get(title="Pro Django") + + # "Dive into Python" exists on the default database, but not on other database + Book.objects.using('other').get(title="Dive into Python") + + with self.assertRaises(Book.DoesNotExist): + Book.objects.get(title="Dive into Python") + with self.assertRaises(Book.DoesNotExist): + Book.objects.using('default').get(title="Dive into Python") + + # "Definitive Guide" exists on the both databases + Book.objects.get(title="The Definitive Guide to Django") + Book.objects.using('default').get(title="The Definitive Guide to Django") + Book.objects.using('other').get(title="The Definitive Guide to Django") + + @override_settings(DATABASE_ROUTERS=[AntiPetRouter()]) + def test_pseudo_empty_fixtures(self): + """ + A fixture can contain entries, but lead to nothing in the database; + this shouldn't raise an error (#14068). 
+ """ + new_io = StringIO() + management.call_command('loaddata', 'pets', stdout=new_io, stderr=new_io) + command_output = new_io.getvalue().strip() + # No objects will actually be loaded + self.assertEqual(command_output, "Installed 0 object(s) (of 2) from 1 fixture(s)") + + +class PickleQuerySetTestCase(TestCase): + multi_db = True + + def test_pickling(self): + for db in connections: + Book.objects.using(db).create(title='Dive into Python', published=datetime.date(2009, 5, 4)) + qs = Book.objects.all() + self.assertEqual(qs.db, pickle.loads(pickle.dumps(qs)).db) + + +class DatabaseReceiver(object): + """ + Used in the tests for the database argument in signals (#13552) + """ + def __call__(self, signal, sender, **kwargs): + self._database = kwargs['using'] + + +class WriteToOtherRouter(object): + """ + A router that sends all writes to the other database. + """ + def db_for_write(self, model, **hints): + return "other" + + +class SignalTests(TestCase): + multi_db = True + + def override_router(self): + return override_settings(DATABASE_ROUTERS=[WriteToOtherRouter()]) + + def test_database_arg_save_and_delete(self): + """ + The pre/post_save signal contains the correct database. 
+ """ + # Make some signal receivers + pre_save_receiver = DatabaseReceiver() + post_save_receiver = DatabaseReceiver() + pre_delete_receiver = DatabaseReceiver() + post_delete_receiver = DatabaseReceiver() + # Make model and connect receivers + signals.pre_save.connect(sender=Person, receiver=pre_save_receiver) + signals.post_save.connect(sender=Person, receiver=post_save_receiver) + signals.pre_delete.connect(sender=Person, receiver=pre_delete_receiver) + signals.post_delete.connect(sender=Person, receiver=post_delete_receiver) + p = Person.objects.create(name='Darth Vader') + # Save and test receivers got calls + p.save() + self.assertEqual(pre_save_receiver._database, DEFAULT_DB_ALIAS) + self.assertEqual(post_save_receiver._database, DEFAULT_DB_ALIAS) + # Delete, and test + p.delete() + self.assertEqual(pre_delete_receiver._database, DEFAULT_DB_ALIAS) + self.assertEqual(post_delete_receiver._database, DEFAULT_DB_ALIAS) + # Save again to a different database + p.save(using="other") + self.assertEqual(pre_save_receiver._database, "other") + self.assertEqual(post_save_receiver._database, "other") + # Delete, and test + p.delete(using="other") + self.assertEqual(pre_delete_receiver._database, "other") + self.assertEqual(post_delete_receiver._database, "other") + + signals.pre_save.disconnect(sender=Person, receiver=pre_save_receiver) + signals.post_save.disconnect(sender=Person, receiver=post_save_receiver) + signals.pre_delete.disconnect(sender=Person, receiver=pre_delete_receiver) + signals.post_delete.disconnect(sender=Person, receiver=post_delete_receiver) + + def test_database_arg_m2m(self): + """ + The m2m_changed signal has a correct database arg. 
+ """ + # Make a receiver + receiver = DatabaseReceiver() + # Connect it + signals.m2m_changed.connect(receiver=receiver) + + # Create the models that will be used for the tests + b = Book.objects.create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + p = Person.objects.create(name="Marty Alchin") + + # Create a copy of the models on the 'other' database to prevent + # integrity errors on backends that don't defer constraints checks + Book.objects.using('other').create(pk=b.pk, title=b.title, + published=b.published) + Person.objects.using('other').create(pk=p.pk, name=p.name) + + # Test addition + b.authors.add(p) + self.assertEqual(receiver._database, DEFAULT_DB_ALIAS) + with self.override_router(): + b.authors.add(p) + self.assertEqual(receiver._database, "other") + + # Test removal + b.authors.remove(p) + self.assertEqual(receiver._database, DEFAULT_DB_ALIAS) + with self.override_router(): + b.authors.remove(p) + self.assertEqual(receiver._database, "other") + + # Test addition in reverse + p.book_set.add(b) + self.assertEqual(receiver._database, DEFAULT_DB_ALIAS) + with self.override_router(): + p.book_set.add(b) + self.assertEqual(receiver._database, "other") + + # Test clearing + b.authors.clear() + self.assertEqual(receiver._database, DEFAULT_DB_ALIAS) + with self.override_router(): + b.authors.clear() + self.assertEqual(receiver._database, "other") + + +class AttributeErrorRouter(object): + "A router to test the exception handling of ConnectionRouter" + def db_for_read(self, model, **hints): + raise AttributeError + + def db_for_write(self, model, **hints): + raise AttributeError + + +class RouterAttributeErrorTestCase(TestCase): + multi_db = True + + def override_router(self): + return override_settings(DATABASE_ROUTERS=[AttributeErrorRouter()]) + + def test_attribute_error_read(self): + "The AttributeError from AttributeErrorRouter bubbles up" + b = Book.objects.create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + with 
self.override_router(): + with self.assertRaises(AttributeError): + Book.objects.get(pk=b.pk) + + def test_attribute_error_save(self): + "The AttributeError from AttributeErrorRouter bubbles up" + dive = Book() + dive.title = "Dive into Python" + dive.published = datetime.date(2009, 5, 4) + with self.override_router(): + with self.assertRaises(AttributeError): + dive.save() + + def test_attribute_error_delete(self): + "The AttributeError from AttributeErrorRouter bubbles up" + b = Book.objects.create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + p = Person.objects.create(name="Marty Alchin") + b.authors.set([p]) + b.editor = p + with self.override_router(): + with self.assertRaises(AttributeError): + b.delete() + + def test_attribute_error_m2m(self): + "The AttributeError from AttributeErrorRouter bubbles up" + b = Book.objects.create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + p = Person.objects.create(name="Marty Alchin") + with self.override_router(): + with self.assertRaises(AttributeError): + b.authors.set([p]) + + +class ModelMetaRouter(object): + "A router to ensure model arguments are real model classes" + def db_for_write(self, model, **hints): + if not hasattr(model, '_meta'): + raise ValueError + + +@override_settings(DATABASE_ROUTERS=[ModelMetaRouter()]) +class RouterModelArgumentTestCase(TestCase): + multi_db = True + + def test_m2m_collection(self): + b = Book.objects.create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + + p = Person.objects.create(name="Marty Alchin") + # test add + b.authors.add(p) + # test remove + b.authors.remove(p) + # test clear + b.authors.clear() + # test setattr + b.authors.set([p]) + # test M2M collection + b.delete() + + def test_foreignkey_collection(self): + person = Person.objects.create(name='Bob') + Pet.objects.create(owner=person, name='Wart') + # test related FK collection + person.delete() + + +class SyncOnlyDefaultDatabaseRouter(object): + def allow_migrate(self, 
db, app_label, **hints): + return db == DEFAULT_DB_ALIAS + + +class MigrateTestCase(TestCase): + + # Limit memory usage when calling 'migrate'. + available_apps = [ + 'multiple_database', + 'django.contrib.auth', + 'django.contrib.contenttypes' + ] + multi_db = True + + def test_migrate_to_other_database(self): + """Regression test for #16039: migrate with --database option.""" + cts = ContentType.objects.using('other').filter(app_label='multiple_database') + + count = cts.count() + self.assertGreater(count, 0) + + cts.delete() + management.call_command('migrate', verbosity=0, interactive=False, database='other') + self.assertEqual(cts.count(), count) + + def test_migrate_to_other_database_with_router(self): + """Regression test for #16039: migrate with --database option.""" + cts = ContentType.objects.using('other').filter(app_label='multiple_database') + + cts.delete() + with override_settings(DATABASE_ROUTERS=[SyncOnlyDefaultDatabaseRouter()]): + management.call_command('migrate', verbosity=0, interactive=False, database='other') + + self.assertEqual(cts.count(), 0) + + +class RouterUsed(Exception): + WRITE = 'write' + + def __init__(self, mode, model, hints): + self.mode = mode + self.model = model + self.hints = hints + + +class RouteForWriteTestCase(TestCase): + multi_db = True + + class WriteCheckRouter(object): + def db_for_write(self, model, **hints): + raise RouterUsed(mode=RouterUsed.WRITE, model=model, hints=hints) + + def override_router(self): + return override_settings(DATABASE_ROUTERS=[RouteForWriteTestCase.WriteCheckRouter()]) + + def test_fk_delete(self): + owner = Person.objects.create(name='Someone') + pet = Pet.objects.create(name='fido', owner=owner) + with self.assertRaises(RouterUsed) as cm: + with self.override_router(): + pet.owner.delete() + e = cm.exception + self.assertEqual(e.mode, RouterUsed.WRITE) + self.assertEqual(e.model, Person) + self.assertEqual(e.hints, {'instance': owner}) + + def test_reverse_fk_delete(self): + owner = 
Person.objects.create(name='Someone') + to_del_qs = owner.pet_set.all() + with self.assertRaises(RouterUsed) as cm: + with self.override_router(): + to_del_qs.delete() + e = cm.exception + self.assertEqual(e.mode, RouterUsed.WRITE) + self.assertEqual(e.model, Pet) + self.assertEqual(e.hints, {'instance': owner}) + + def test_reverse_fk_get_or_create(self): + owner = Person.objects.create(name='Someone') + with self.assertRaises(RouterUsed) as cm: + with self.override_router(): + owner.pet_set.get_or_create(name='fido') + e = cm.exception + self.assertEqual(e.mode, RouterUsed.WRITE) + self.assertEqual(e.model, Pet) + self.assertEqual(e.hints, {'instance': owner}) + + def test_reverse_fk_update(self): + owner = Person.objects.create(name='Someone') + Pet.objects.create(name='fido', owner=owner) + with self.assertRaises(RouterUsed) as cm: + with self.override_router(): + owner.pet_set.update(name='max') + e = cm.exception + self.assertEqual(e.mode, RouterUsed.WRITE) + self.assertEqual(e.model, Pet) + self.assertEqual(e.hints, {'instance': owner}) + + def test_m2m_add(self): + auth = Person.objects.create(name='Someone') + book = Book.objects.create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + with self.assertRaises(RouterUsed) as cm: + with self.override_router(): + book.authors.add(auth) + e = cm.exception + self.assertEqual(e.mode, RouterUsed.WRITE) + self.assertEqual(e.model, Book.authors.through) + self.assertEqual(e.hints, {'instance': book}) + + def test_m2m_clear(self): + auth = Person.objects.create(name='Someone') + book = Book.objects.create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + book.authors.add(auth) + with self.assertRaises(RouterUsed) as cm: + with self.override_router(): + book.authors.clear() + e = cm.exception + self.assertEqual(e.mode, RouterUsed.WRITE) + self.assertEqual(e.model, Book.authors.through) + self.assertEqual(e.hints, {'instance': book}) + + def test_m2m_delete(self): + auth = 
Person.objects.create(name='Someone') + book = Book.objects.create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + book.authors.add(auth) + with self.assertRaises(RouterUsed) as cm: + with self.override_router(): + book.authors.all().delete() + e = cm.exception + self.assertEqual(e.mode, RouterUsed.WRITE) + self.assertEqual(e.model, Person) + self.assertEqual(e.hints, {'instance': book}) + + def test_m2m_get_or_create(self): + Person.objects.create(name='Someone') + book = Book.objects.create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + with self.assertRaises(RouterUsed) as cm: + with self.override_router(): + book.authors.get_or_create(name='Someone else') + e = cm.exception + self.assertEqual(e.mode, RouterUsed.WRITE) + self.assertEqual(e.model, Book) + self.assertEqual(e.hints, {'instance': book}) + + def test_m2m_remove(self): + auth = Person.objects.create(name='Someone') + book = Book.objects.create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + book.authors.add(auth) + with self.assertRaises(RouterUsed) as cm: + with self.override_router(): + book.authors.remove(auth) + e = cm.exception + self.assertEqual(e.mode, RouterUsed.WRITE) + self.assertEqual(e.model, Book.authors.through) + self.assertEqual(e.hints, {'instance': book}) + + def test_m2m_update(self): + auth = Person.objects.create(name='Someone') + book = Book.objects.create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + book.authors.add(auth) + with self.assertRaises(RouterUsed) as cm: + with self.override_router(): + book.authors.all().update(name='Different') + e = cm.exception + self.assertEqual(e.mode, RouterUsed.WRITE) + self.assertEqual(e.model, Person) + self.assertEqual(e.hints, {'instance': book}) + + def test_reverse_m2m_add(self): + auth = Person.objects.create(name='Someone') + book = Book.objects.create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + with self.assertRaises(RouterUsed) as cm: + with 
self.override_router(): + auth.book_set.add(book) + e = cm.exception + self.assertEqual(e.mode, RouterUsed.WRITE) + self.assertEqual(e.model, Book.authors.through) + self.assertEqual(e.hints, {'instance': auth}) + + def test_reverse_m2m_clear(self): + auth = Person.objects.create(name='Someone') + book = Book.objects.create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + book.authors.add(auth) + with self.assertRaises(RouterUsed) as cm: + with self.override_router(): + auth.book_set.clear() + e = cm.exception + self.assertEqual(e.mode, RouterUsed.WRITE) + self.assertEqual(e.model, Book.authors.through) + self.assertEqual(e.hints, {'instance': auth}) + + def test_reverse_m2m_delete(self): + auth = Person.objects.create(name='Someone') + book = Book.objects.create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + book.authors.add(auth) + with self.assertRaises(RouterUsed) as cm: + with self.override_router(): + auth.book_set.all().delete() + e = cm.exception + self.assertEqual(e.mode, RouterUsed.WRITE) + self.assertEqual(e.model, Book) + self.assertEqual(e.hints, {'instance': auth}) + + def test_reverse_m2m_get_or_create(self): + auth = Person.objects.create(name='Someone') + Book.objects.create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + with self.assertRaises(RouterUsed) as cm: + with self.override_router(): + auth.book_set.get_or_create(title="New Book", published=datetime.datetime.now()) + e = cm.exception + self.assertEqual(e.mode, RouterUsed.WRITE) + self.assertEqual(e.model, Person) + self.assertEqual(e.hints, {'instance': auth}) + + def test_reverse_m2m_remove(self): + auth = Person.objects.create(name='Someone') + book = Book.objects.create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + book.authors.add(auth) + with self.assertRaises(RouterUsed) as cm: + with self.override_router(): + auth.book_set.remove(book) + e = cm.exception + self.assertEqual(e.mode, RouterUsed.WRITE) + 
self.assertEqual(e.model, Book.authors.through) + self.assertEqual(e.hints, {'instance': auth}) + + def test_reverse_m2m_update(self): + auth = Person.objects.create(name='Someone') + book = Book.objects.create(title="Pro Django", + published=datetime.date(2008, 12, 16)) + book.authors.add(auth) + with self.assertRaises(RouterUsed) as cm: + with self.override_router(): + auth.book_set.all().update(title='Different') + e = cm.exception + self.assertEqual(e.mode, RouterUsed.WRITE) + self.assertEqual(e.model, Book) + self.assertEqual(e.hints, {'instance': auth}) diff --git a/tests/nested_foreign_keys/__init__.py b/tests/nested_foreign_keys/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/nested_foreign_keys/models.py b/tests/nested_foreign_keys/models.py new file mode 100644 index 00000000..5805de5d --- /dev/null +++ b/tests/nested_foreign_keys/models.py @@ -0,0 +1,30 @@ +from django.db import models + + +class Person(models.Model): + name = models.CharField(max_length=200) + + +class Movie(models.Model): + title = models.CharField(max_length=200) + director = models.ForeignKey(Person, models.CASCADE) + + +class Event(models.Model): + pass + + +class Screening(Event): + movie = models.ForeignKey(Movie, models.CASCADE) + + +class ScreeningNullFK(Event): + movie = models.ForeignKey(Movie, models.SET_NULL, null=True) + + +class Package(models.Model): + screening = models.ForeignKey(Screening, models.SET_NULL, null=True) + + +class PackageNullFK(models.Model): + screening = models.ForeignKey(ScreeningNullFK, models.SET_NULL, null=True) diff --git a/tests/nested_foreign_keys/tests.py b/tests/nested_foreign_keys/tests.py new file mode 100644 index 00000000..34a3703e --- /dev/null +++ b/tests/nested_foreign_keys/tests.py @@ -0,0 +1,176 @@ +from __future__ import unicode_literals + +from django.test import TestCase + +from .models import ( + Event, Movie, Package, PackageNullFK, Person, Screening, ScreeningNullFK, +) + + +# These are tests for 
#16715. The basic scheme is always the same: 3 models with +# 2 relations. The first relation may be null, while the second is non-nullable. +# In some cases, Django would pick the wrong join type for the second relation, +# resulting in missing objects in the queryset. +# +# Model A +# | (Relation A/B : nullable) +# Model B +# | (Relation B/C : non-nullable) +# Model C +# +# Because of the possibility of NULL rows resulting from the LEFT OUTER JOIN +# between Model A and Model B (i.e. instances of A without reference to B), +# the second join must also be LEFT OUTER JOIN, so that we do not ignore +# instances of A that do not reference B. +# +# Relation A/B can either be an explicit foreign key or an implicit reverse +# relation such as introduced by one-to-one relations (through multi-table +# inheritance). +class NestedForeignKeysTests(TestCase): + def setUp(self): + self.director = Person.objects.create(name='Terry Gilliam / Terry Jones') + self.movie = Movie.objects.create(title='Monty Python and the Holy Grail', director=self.director) + + # This test failed in #16715 because in some cases INNER JOIN was selected + # for the second foreign key relation instead of LEFT OUTER JOIN. + def test_inheritance(self): + Event.objects.create() + Screening.objects.create(movie=self.movie) + + self.assertEqual(len(Event.objects.all()), 2) + self.assertEqual(len(Event.objects.select_related('screening')), 2) + # This failed. + self.assertEqual(len(Event.objects.select_related('screening__movie')), 2) + + self.assertEqual(len(Event.objects.values()), 2) + self.assertEqual(len(Event.objects.values('screening__pk')), 2) + self.assertEqual(len(Event.objects.values('screening__movie__pk')), 2) + self.assertEqual(len(Event.objects.values('screening__movie__title')), 2) + # This failed. + self.assertEqual(len(Event.objects.values('screening__movie__pk', 'screening__movie__title')), 2) + + # Simple filter/exclude queries for good measure. 
+ self.assertEqual(Event.objects.filter(screening__movie=self.movie).count(), 1) + self.assertEqual(Event.objects.exclude(screening__movie=self.movie).count(), 1) + + # These all work because the second foreign key in the chain has null=True. + def test_inheritance_null_FK(self): + Event.objects.create() + ScreeningNullFK.objects.create(movie=None) + ScreeningNullFK.objects.create(movie=self.movie) + + self.assertEqual(len(Event.objects.all()), 3) + self.assertEqual(len(Event.objects.select_related('screeningnullfk')), 3) + self.assertEqual(len(Event.objects.select_related('screeningnullfk__movie')), 3) + + self.assertEqual(len(Event.objects.values()), 3) + self.assertEqual(len(Event.objects.values('screeningnullfk__pk')), 3) + self.assertEqual(len(Event.objects.values('screeningnullfk__movie__pk')), 3) + self.assertEqual(len(Event.objects.values('screeningnullfk__movie__title')), 3) + self.assertEqual(len(Event.objects.values('screeningnullfk__movie__pk', 'screeningnullfk__movie__title')), 3) + + self.assertEqual(Event.objects.filter(screeningnullfk__movie=self.movie).count(), 1) + self.assertEqual(Event.objects.exclude(screeningnullfk__movie=self.movie).count(), 2) + + def test_null_exclude(self): + screening = ScreeningNullFK.objects.create(movie=None) + ScreeningNullFK.objects.create(movie=self.movie) + self.assertEqual( + list(ScreeningNullFK.objects.exclude(movie__id=self.movie.pk)), + [screening]) + + # This test failed in #16715 because in some cases INNER JOIN was selected + # for the second foreign key relation instead of LEFT OUTER JOIN. 
+ def test_explicit_ForeignKey(self): + Package.objects.create() + screening = Screening.objects.create(movie=self.movie) + Package.objects.create(screening=screening) + + self.assertEqual(len(Package.objects.all()), 2) + self.assertEqual(len(Package.objects.select_related('screening')), 2) + self.assertEqual(len(Package.objects.select_related('screening__movie')), 2) + + self.assertEqual(len(Package.objects.values()), 2) + self.assertEqual(len(Package.objects.values('screening__pk')), 2) + self.assertEqual(len(Package.objects.values('screening__movie__pk')), 2) + self.assertEqual(len(Package.objects.values('screening__movie__title')), 2) + # This failed. + self.assertEqual(len(Package.objects.values('screening__movie__pk', 'screening__movie__title')), 2) + + self.assertEqual(Package.objects.filter(screening__movie=self.movie).count(), 1) + self.assertEqual(Package.objects.exclude(screening__movie=self.movie).count(), 1) + + # These all work because the second foreign key in the chain has null=True. 
+ def test_explicit_ForeignKey_NullFK(self): + PackageNullFK.objects.create() + screening = ScreeningNullFK.objects.create(movie=None) + screening_with_movie = ScreeningNullFK.objects.create(movie=self.movie) + PackageNullFK.objects.create(screening=screening) + PackageNullFK.objects.create(screening=screening_with_movie) + + self.assertEqual(len(PackageNullFK.objects.all()), 3) + self.assertEqual(len(PackageNullFK.objects.select_related('screening')), 3) + self.assertEqual(len(PackageNullFK.objects.select_related('screening__movie')), 3) + + self.assertEqual(len(PackageNullFK.objects.values()), 3) + self.assertEqual(len(PackageNullFK.objects.values('screening__pk')), 3) + self.assertEqual(len(PackageNullFK.objects.values('screening__movie__pk')), 3) + self.assertEqual(len(PackageNullFK.objects.values('screening__movie__title')), 3) + self.assertEqual(len(PackageNullFK.objects.values('screening__movie__pk', 'screening__movie__title')), 3) + + self.assertEqual(PackageNullFK.objects.filter(screening__movie=self.movie).count(), 1) + self.assertEqual(PackageNullFK.objects.exclude(screening__movie=self.movie).count(), 2) + + +# Some additional tests for #16715. The only difference is the depth of the +# nesting as we now use 4 models instead of 3 (and thus 3 relations). This +# checks if promotion of join types works for deeper nesting too. 
+class DeeplyNestedForeignKeysTests(TestCase): + def setUp(self): + self.director = Person.objects.create(name='Terry Gilliam / Terry Jones') + self.movie = Movie.objects.create(title='Monty Python and the Holy Grail', director=self.director) + + def test_inheritance(self): + Event.objects.create() + Screening.objects.create(movie=self.movie) + + self.assertEqual(len(Event.objects.all()), 2) + self.assertEqual(len(Event.objects.select_related('screening__movie__director')), 2) + + self.assertEqual(len(Event.objects.values()), 2) + self.assertEqual(len(Event.objects.values('screening__movie__director__pk')), 2) + self.assertEqual(len(Event.objects.values('screening__movie__director__name')), 2) + self.assertEqual( + len(Event.objects.values('screening__movie__director__pk', 'screening__movie__director__name')), + 2 + ) + self.assertEqual(len(Event.objects.values('screening__movie__pk', 'screening__movie__director__pk')), 2) + self.assertEqual(len(Event.objects.values('screening__movie__pk', 'screening__movie__director__name')), 2) + self.assertEqual(len(Event.objects.values('screening__movie__title', 'screening__movie__director__pk')), 2) + self.assertEqual(len(Event.objects.values('screening__movie__title', 'screening__movie__director__name')), 2) + + self.assertEqual(Event.objects.filter(screening__movie__director=self.director).count(), 1) + self.assertEqual(Event.objects.exclude(screening__movie__director=self.director).count(), 1) + + def test_explicit_ForeignKey(self): + Package.objects.create() + screening = Screening.objects.create(movie=self.movie) + Package.objects.create(screening=screening) + + self.assertEqual(len(Package.objects.all()), 2) + self.assertEqual(len(Package.objects.select_related('screening__movie__director')), 2) + + self.assertEqual(len(Package.objects.values()), 2) + self.assertEqual(len(Package.objects.values('screening__movie__director__pk')), 2) + self.assertEqual(len(Package.objects.values('screening__movie__director__name')), 2) + 
self.assertEqual( + len(Package.objects.values('screening__movie__director__pk', 'screening__movie__director__name')), + 2 + ) + self.assertEqual(len(Package.objects.values('screening__movie__pk', 'screening__movie__director__pk')), 2) + self.assertEqual(len(Package.objects.values('screening__movie__pk', 'screening__movie__director__name')), 2) + self.assertEqual(len(Package.objects.values('screening__movie__title', 'screening__movie__director__pk')), 2) + self.assertEqual(len(Package.objects.values('screening__movie__title', 'screening__movie__director__name')), 2) + + self.assertEqual(Package.objects.filter(screening__movie__director=self.director).count(), 1) + self.assertEqual(Package.objects.exclude(screening__movie__director=self.director).count(), 1) diff --git a/tests/null_fk/__init__.py b/tests/null_fk/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/null_fk/models.py b/tests/null_fk/models.py new file mode 100644 index 00000000..6a7da8f6 --- /dev/null +++ b/tests/null_fk/models.py @@ -0,0 +1,57 @@ +""" +Regression tests for proper working of ForeignKey(null=True). 
+""" + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +class SystemDetails(models.Model): + details = models.TextField() + + +class SystemInfo(models.Model): + system_details = models.ForeignKey(SystemDetails, models.CASCADE) + system_name = models.CharField(max_length=32) + + +class Forum(models.Model): + system_info = models.ForeignKey(SystemInfo, models.CASCADE) + forum_name = models.CharField(max_length=32) + + +@python_2_unicode_compatible +class Post(models.Model): + forum = models.ForeignKey(Forum, models.SET_NULL, null=True) + title = models.CharField(max_length=32) + + def __str__(self): + return self.title + + +@python_2_unicode_compatible +class Comment(models.Model): + post = models.ForeignKey(Post, models.SET_NULL, null=True) + comment_text = models.CharField(max_length=250) + + class Meta: + ordering = ('comment_text',) + + def __str__(self): + return self.comment_text + +# Ticket 15823 + + +class Item(models.Model): + title = models.CharField(max_length=100) + + +class PropertyValue(models.Model): + label = models.CharField(max_length=100) + + +class Property(models.Model): + item = models.ForeignKey(Item, models.CASCADE, related_name='props') + key = models.CharField(max_length=100) + value = models.ForeignKey(PropertyValue, models.SET_NULL, null=True) diff --git a/tests/null_fk/tests.py b/tests/null_fk/tests.py new file mode 100644 index 00000000..19b285e3 --- /dev/null +++ b/tests/null_fk/tests.py @@ -0,0 +1,70 @@ +from __future__ import unicode_literals + +from django.db.models import Q +from django.test import TestCase + +from .models import ( + Comment, Forum, Item, Post, PropertyValue, SystemDetails, SystemInfo, +) + + +class NullFkTests(TestCase): + + def test_null_fk(self): + d = SystemDetails.objects.create(details='First details') + s = SystemInfo.objects.create(system_name='First forum', system_details=d) + f = Forum.objects.create(system_info=s, forum_name='First forum') + p = 
Post.objects.create(forum=f, title='First Post') + c1 = Comment.objects.create(post=p, comment_text='My first comment') + c2 = Comment.objects.create(comment_text='My second comment') + + # Starting from comment, make sure that a .select_related(...) with a specified + # set of fields will properly LEFT JOIN multiple levels of NULLs (and the things + # that come after the NULLs, or else data that should exist won't). Regression + # test for #7369. + c = Comment.objects.select_related().get(id=c1.id) + self.assertEqual(c.post, p) + self.assertIsNone(Comment.objects.select_related().get(id=c2.id).post) + + self.assertQuerysetEqual( + Comment.objects.select_related('post__forum__system_info').all(), + [ + (c1.id, 'My first comment', ''), + (c2.id, 'My second comment', 'None') + ], + transform=lambda c: (c.id, c.comment_text, repr(c.post)) + ) + + # Regression test for #7530, #7716. + self.assertIsNone(Comment.objects.select_related('post').filter(post__isnull=True)[0].post) + + self.assertQuerysetEqual( + Comment.objects.select_related('post__forum__system_info__system_details'), + [ + (c1.id, 'My first comment', ''), + (c2.id, 'My second comment', 'None') + ], + transform=lambda c: (c.id, c.comment_text, repr(c.post)) + ) + + def test_combine_isnull(self): + item = Item.objects.create(title='Some Item') + pv = PropertyValue.objects.create(label='Some Value') + item.props.create(key='a', value=pv) + item.props.create(key='b') # value=NULL + q1 = Q(props__key='a', props__value=pv) + q2 = Q(props__key='b', props__value__isnull=True) + + # Each of these individually should return the item. + self.assertEqual(Item.objects.get(q1), item) + self.assertEqual(Item.objects.get(q2), item) + + # Logically, qs1 and qs2, and qs3 and qs4 should be the same. 
+ qs1 = Item.objects.filter(q1) & Item.objects.filter(q2) + qs2 = Item.objects.filter(q2) & Item.objects.filter(q1) + qs3 = Item.objects.filter(q1) | Item.objects.filter(q2) + qs4 = Item.objects.filter(q2) | Item.objects.filter(q1) + + # Regression test for #15823. + self.assertEqual(list(qs1), list(qs2)) + self.assertEqual(list(qs3), list(qs4)) diff --git a/tests/null_fk_ordering/__init__.py b/tests/null_fk_ordering/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/null_fk_ordering/models.py b/tests/null_fk_ordering/models.py new file mode 100644 index 00000000..0bac5e1a --- /dev/null +++ b/tests/null_fk_ordering/models.py @@ -0,0 +1,59 @@ +""" +Regression tests for proper working of ForeignKey(null=True). Tests these bugs: + + * #7512: including a nullable foreign key reference in Meta ordering has +unexpected results + +""" +from __future__ import unicode_literals + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +# The first two models represent a very simple null FK ordering case. +class Author(models.Model): + name = models.CharField(max_length=150) + + +@python_2_unicode_compatible +class Article(models.Model): + title = models.CharField(max_length=150) + author = models.ForeignKey(Author, models.SET_NULL, null=True) + + def __str__(self): + return 'Article titled: %s' % (self.title, ) + + class Meta: + ordering = ['author__name', ] + + +# These following 4 models represent a far more complex ordering case. 
+class SystemInfo(models.Model): + system_name = models.CharField(max_length=32) + + +class Forum(models.Model): + system_info = models.ForeignKey(SystemInfo, models.CASCADE) + forum_name = models.CharField(max_length=32) + + +@python_2_unicode_compatible +class Post(models.Model): + forum = models.ForeignKey(Forum, models.SET_NULL, null=True) + title = models.CharField(max_length=32) + + def __str__(self): + return self.title + + +@python_2_unicode_compatible +class Comment(models.Model): + post = models.ForeignKey(Post, models.SET_NULL, null=True) + comment_text = models.CharField(max_length=250) + + class Meta: + ordering = ['post__forum__system_info__system_name', 'comment_text'] + + def __str__(self): + return self.comment_text diff --git a/tests/null_fk_ordering/tests.py b/tests/null_fk_ordering/tests.py new file mode 100644 index 00000000..7215118b --- /dev/null +++ b/tests/null_fk_ordering/tests.py @@ -0,0 +1,42 @@ +from __future__ import unicode_literals + +from django.test import TestCase + +from .models import Article, Author, Comment, Forum, Post, SystemInfo + + +class NullFkOrderingTests(TestCase): + + def test_ordering_across_null_fk(self): + """ + Regression test for #7512 + + ordering across nullable Foreign Keys shouldn't exclude results + """ + author_1 = Author.objects.create(name='Tom Jones') + author_2 = Author.objects.create(name='Bob Smith') + Article.objects.create(title='No author on this article') + Article.objects.create(author=author_1, title='This article written by Tom Jones') + Article.objects.create(author=author_2, title='This article written by Bob Smith') + + # We can't compare results directly (since different databases sort NULLs to + # different ends of the ordering), but we can check that all results are + # returned. 
+ self.assertEqual(len(list(Article.objects.all())), 3) + + s = SystemInfo.objects.create(system_name='System Info') + f = Forum.objects.create(system_info=s, forum_name='First forum') + p = Post.objects.create(forum=f, title='First Post') + Comment.objects.create(post=p, comment_text='My first comment') + Comment.objects.create(comment_text='My second comment') + s2 = SystemInfo.objects.create(system_name='More System Info') + f2 = Forum.objects.create(system_info=s2, forum_name='Second forum') + p2 = Post.objects.create(forum=f2, title='Second Post') + Comment.objects.create(comment_text='Another first comment') + Comment.objects.create(post=p2, comment_text='Another second comment') + + # We have to test this carefully. Some databases sort NULL values before + # everything else, some sort them afterwards. So we extract the ordered list + # and check the length. Before the fix, this list was too short (some values + # were omitted). + self.assertEqual(len(list(Comment.objects.all())), 4) diff --git a/tests/or_lookups/__init__.py b/tests/or_lookups/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/or_lookups/models.py b/tests/or_lookups/models.py new file mode 100644 index 00000000..7dea8cd4 --- /dev/null +++ b/tests/or_lookups/models.py @@ -0,0 +1,25 @@ +""" +OR lookups + +To perform an OR lookup, or a lookup that combines ANDs and ORs, combine +``QuerySet`` objects using ``&`` and ``|`` operators. + +Alternatively, use positional arguments, and pass one or more expressions of +clauses using the variable ``django.db.models.Q`` (or any object with an +``add_to_query`` method). 
+""" + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Article(models.Model): + headline = models.CharField(max_length=50) + pub_date = models.DateTimeField() + + class Meta: + ordering = ('pub_date',) + + def __str__(self): + return self.headline diff --git a/tests/or_lookups/tests.py b/tests/or_lookups/tests.py new file mode 100644 index 00000000..5e8b3dd6 --- /dev/null +++ b/tests/or_lookups/tests.py @@ -0,0 +1,243 @@ +# -*- encoding: utf-8 -*- +from __future__ import unicode_literals + +from datetime import datetime +from operator import attrgetter + +from django.db.models import Q +from django.test import TestCase +from django.utils.encoding import force_str + +from .models import Article + + +class OrLookupsTests(TestCase): + + def setUp(self): + self.a1 = Article.objects.create( + headline='Hello', pub_date=datetime(2005, 11, 27) + ).pk + self.a2 = Article.objects.create( + headline='Goodbye', pub_date=datetime(2005, 11, 28) + ).pk + self.a3 = Article.objects.create( + headline='Hello and goodbye', pub_date=datetime(2005, 11, 29) + ).pk + + def test_filter_or(self): + self.assertQuerysetEqual( + ( + Article.objects.filter(headline__startswith='Hello') | + Article.objects.filter(headline__startswith='Goodbye') + ), [ + 'Hello', + 'Goodbye', + 'Hello and goodbye' + ], + attrgetter("headline") + ) + + self.assertQuerysetEqual( + Article.objects.filter(headline__contains='Hello') | Article.objects.filter(headline__contains='bye'), [ + 'Hello', + 'Goodbye', + 'Hello and goodbye' + ], + attrgetter("headline") + ) + + self.assertQuerysetEqual( + Article.objects.filter(headline__iexact='Hello') | Article.objects.filter(headline__contains='ood'), [ + 'Hello', + 'Goodbye', + 'Hello and goodbye' + ], + attrgetter("headline") + ) + + self.assertQuerysetEqual( + Article.objects.filter(Q(headline__startswith='Hello') | Q(headline__startswith='Goodbye')), [ + 'Hello', + 'Goodbye', + 
'Hello and goodbye' + ], + attrgetter("headline") + ) + + def test_stages(self): + # You can shorten this syntax with code like the following, which is + # especially useful if building the query in stages: + articles = Article.objects.all() + self.assertQuerysetEqual( + articles.filter(headline__startswith='Hello') & articles.filter(headline__startswith='Goodbye'), + [] + ) + self.assertQuerysetEqual( + articles.filter(headline__startswith='Hello') & articles.filter(headline__contains='bye'), [ + 'Hello and goodbye' + ], + attrgetter("headline") + ) + + def test_pk_q(self): + self.assertQuerysetEqual( + Article.objects.filter(Q(pk=self.a1) | Q(pk=self.a2)), [ + 'Hello', + 'Goodbye' + ], + attrgetter("headline") + ) + + self.assertQuerysetEqual( + Article.objects.filter(Q(pk=self.a1) | Q(pk=self.a2) | Q(pk=self.a3)), [ + 'Hello', + 'Goodbye', + 'Hello and goodbye' + ], + attrgetter("headline"), + ) + + def test_pk_in(self): + self.assertQuerysetEqual( + Article.objects.filter(pk__in=[self.a1, self.a2, self.a3]), [ + 'Hello', + 'Goodbye', + 'Hello and goodbye' + ], + attrgetter("headline"), + ) + + self.assertQuerysetEqual( + Article.objects.filter(pk__in=(self.a1, self.a2, self.a3)), [ + 'Hello', + 'Goodbye', + 'Hello and goodbye' + ], + attrgetter("headline"), + ) + + self.assertQuerysetEqual( + Article.objects.filter(pk__in=[self.a1, self.a2, self.a3, 40000]), [ + 'Hello', + 'Goodbye', + 'Hello and goodbye' + ], + attrgetter("headline"), + ) + + def test_q_repr(self): + or_expr = Q(baz=Article(headline="Foö")) + self.assertEqual(repr(or_expr), force_str("))>")) + negated_or = ~Q(baz=Article(headline="Foö")) + self.assertEqual(repr(negated_or), force_str(")))>")) + + def test_q_negated(self): + # Q objects can be negated + self.assertQuerysetEqual( + Article.objects.filter(Q(pk=self.a1) | ~Q(pk=self.a2)), [ + 'Hello', + 'Hello and goodbye' + ], + attrgetter("headline") + ) + + self.assertQuerysetEqual( + Article.objects.filter(~Q(pk=self.a1) & ~Q(pk=self.a2)), [ + 
'Hello and goodbye' + ], + attrgetter("headline"), + ) + # This allows for more complex queries than filter() and exclude() + # alone would allow + self.assertQuerysetEqual( + Article.objects.filter(Q(pk=self.a1) & (~Q(pk=self.a2) | Q(pk=self.a3))), [ + 'Hello' + ], + attrgetter("headline"), + ) + + def test_complex_filter(self): + # The 'complex_filter' method supports framework features such as + # 'limit_choices_to' which normally take a single dictionary of lookup + # arguments but need to support arbitrary queries via Q objects too. + self.assertQuerysetEqual( + Article.objects.complex_filter({'pk': self.a1}), [ + 'Hello' + ], + attrgetter("headline"), + ) + + self.assertQuerysetEqual( + Article.objects.complex_filter(Q(pk=self.a1) | Q(pk=self.a2)), [ + 'Hello', + 'Goodbye' + ], + attrgetter("headline"), + ) + + def test_empty_in(self): + # Passing "in" an empty list returns no results ... + self.assertQuerysetEqual( + Article.objects.filter(pk__in=[]), + [] + ) + # ... but can return results if we OR it with another query. 
+ self.assertQuerysetEqual( + Article.objects.filter(Q(pk__in=[]) | Q(headline__icontains='goodbye')), [ + 'Goodbye', + 'Hello and goodbye' + ], + attrgetter("headline"), + ) + + def test_q_and(self): + # Q arg objects are ANDed + self.assertQuerysetEqual( + Article.objects.filter(Q(headline__startswith='Hello'), Q(headline__contains='bye')), [ + 'Hello and goodbye' + ], + attrgetter("headline") + ) + # Q arg AND order is irrelevant + self.assertQuerysetEqual( + Article.objects.filter(Q(headline__contains='bye'), headline__startswith='Hello'), [ + 'Hello and goodbye' + ], + attrgetter("headline"), + ) + + self.assertQuerysetEqual( + Article.objects.filter(Q(headline__startswith='Hello') & Q(headline__startswith='Goodbye')), + [] + ) + + def test_q_exclude(self): + self.assertQuerysetEqual( + Article.objects.exclude(Q(headline__startswith='Hello')), [ + 'Goodbye' + ], + attrgetter("headline") + ) + + def test_other_arg_queries(self): + # Try some arg queries with operations other than filter. 
+ self.assertEqual( + Article.objects.get(Q(headline__startswith='Hello'), Q(headline__contains='bye')).headline, + 'Hello and goodbye' + ) + + self.assertEqual( + Article.objects.filter(Q(headline__startswith='Hello') | Q(headline__contains='bye')).count(), + 3 + ) + + self.assertSequenceEqual( + Article.objects.filter(Q(headline__startswith='Hello'), Q(headline__contains='bye')).values(), [ + {"headline": "Hello and goodbye", "id": self.a3, "pub_date": datetime(2005, 11, 29)}, + ], + ) + + self.assertEqual( + Article.objects.filter(Q(headline__startswith='Hello')).in_bulk([self.a1, self.a2]), + {self.a1: Article.objects.get(pk=self.a1)} + ) diff --git a/tests/ordering/__init__.py b/tests/ordering/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/ordering/models.py b/tests/ordering/models.py new file mode 100644 index 00000000..1f794ca3 --- /dev/null +++ b/tests/ordering/models.py @@ -0,0 +1,51 @@ +""" +Specifying ordering + +Specify default ordering for a model using the ``ordering`` attribute, which +should be a list or tuple of field names. This tells Django how to order +``QuerySet`` results. + +If a field name in ``ordering`` starts with a hyphen, that field will be +ordered in descending order. Otherwise, it'll be ordered in ascending order. +The special-case field name ``"?"`` specifies random order. + +The ordering attribute is not required. If you leave it off, ordering will be +undefined -- not random, just undefined. 
+""" + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +class Author(models.Model): + name = models.CharField(max_length=63, null=True, blank=True) + + class Meta: + ordering = ('-pk',) + + +@python_2_unicode_compatible +class Article(models.Model): + author = models.ForeignKey(Author, models.SET_NULL, null=True) + second_author = models.ForeignKey(Author, models.SET_NULL, null=True, related_name='+') + headline = models.CharField(max_length=100) + pub_date = models.DateTimeField() + + class Meta: + ordering = ('-pub_date', 'headline') + + def __str__(self): + return self.headline + + +class OrderedByAuthorArticle(Article): + class Meta: + proxy = True + ordering = ('author', 'second_author') + + +class Reference(models.Model): + article = models.ForeignKey(OrderedByAuthorArticle, models.CASCADE) + + class Meta: + ordering = ('article',) diff --git a/tests/ordering/tests.py b/tests/ordering/tests.py new file mode 100644 index 00000000..cfdb3bbf --- /dev/null +++ b/tests/ordering/tests.py @@ -0,0 +1,369 @@ +from __future__ import unicode_literals + +from datetime import datetime +from operator import attrgetter + +import django +from django.db.models import F +from django.db.models.functions import Upper +from django.test import TestCase + +from .models import Article, Author, Reference + + +class OrderingTests(TestCase): + + @classmethod + def setUpTestData(cls): + cls.a1 = Article.objects.create(headline="Article 1", pub_date=datetime(2005, 7, 26)) + cls.a2 = Article.objects.create(headline="Article 2", pub_date=datetime(2005, 7, 27)) + cls.a3 = Article.objects.create(headline="Article 3", pub_date=datetime(2005, 7, 27)) + cls.a4 = Article.objects.create(headline="Article 4", pub_date=datetime(2005, 7, 28)) + cls.author_1 = Author.objects.create(name="Name 1") + cls.author_2 = Author.objects.create(name="Name 2") + for i in range(2): + Author.objects.create() + + def test_default_ordering(self): + """ + By default, 
Article.objects.all() orders by pub_date descending, then + headline ascending. + """ + self.assertQuerysetEqual( + Article.objects.all(), [ + "Article 4", + "Article 2", + "Article 3", + "Article 1", + ], + attrgetter("headline") + ) + + # Getting a single item should work too: + self.assertEqual(Article.objects.all()[0], self.a4) + + def test_default_ordering_override(self): + """ + Override ordering with order_by, which is in the same format as the + ordering attribute in models. + """ + self.assertQuerysetEqual( + Article.objects.order_by("headline"), [ + "Article 1", + "Article 2", + "Article 3", + "Article 4", + ], + attrgetter("headline") + ) + self.assertQuerysetEqual( + Article.objects.order_by("pub_date", "-headline"), [ + "Article 1", + "Article 3", + "Article 2", + "Article 4", + ], + attrgetter("headline") + ) + + def test_order_by_override(self): + """ + Only the last order_by has any effect (since they each override any + previous ordering). + """ + self.assertQuerysetEqual( + Article.objects.order_by("id"), [ + "Article 1", + "Article 2", + "Article 3", + "Article 4", + ], + attrgetter("headline") + ) + self.assertQuerysetEqual( + Article.objects.order_by("id").order_by("-headline"), [ + "Article 4", + "Article 3", + "Article 2", + "Article 1", + ], + attrgetter("headline") + ) + + def test_order_by_nulls_first_and_last(self): + if django.VERSION < (1, 11, 0): + self.skipTest("Only run this on Django 1.11 or newer") + msg = "nulls_first and nulls_last are mutually exclusive" + with self.assertRaisesMessage(ValueError, msg): + Article.objects.order_by(F("author").desc(nulls_last=True, nulls_first=True)) + + def test_order_by_nulls_last(self): + self.skipTest("TODO fix django.db.utils.ProgrammingError: Incorrect syntax near 'NULLS'.") + Article.objects.filter(headline="Article 3").update(author=self.author_1) + Article.objects.filter(headline="Article 4").update(author=self.author_2) + # asc and desc are chainable with nulls_last. 
+ self.assertSequenceEqual( + Article.objects.order_by(F("author").desc(nulls_last=True)), + [self.a4, self.a3, self.a1, self.a2], + ) + self.assertSequenceEqual( + Article.objects.order_by(F("author").asc(nulls_last=True)), + [self.a3, self.a4, self.a1, self.a2], + ) + self.assertSequenceEqual( + Article.objects.order_by(Upper("author__name").desc(nulls_last=True)), + [self.a4, self.a3, self.a1, self.a2], + ) + self.assertSequenceEqual( + Article.objects.order_by(Upper("author__name").asc(nulls_last=True)), + [self.a3, self.a4, self.a1, self.a2], + ) + + def test_order_by_nulls_first(self): + self.skipTest("TODO fix django.db.utils.ProgrammingError: Incorrect syntax near 'NULLS'.") + Article.objects.filter(headline="Article 3").update(author=self.author_1) + Article.objects.filter(headline="Article 4").update(author=self.author_2) + # asc and desc are chainable with nulls_first. + self.assertSequenceEqual( + Article.objects.order_by(F("author").asc(nulls_first=True)), + [self.a1, self.a2, self.a3, self.a4], + ) + self.assertSequenceEqual( + Article.objects.order_by(F("author").desc(nulls_first=True)), + [self.a1, self.a2, self.a4, self.a3], + ) + self.assertSequenceEqual( + Article.objects.order_by(Upper("author__name").asc(nulls_first=True)), + [self.a1, self.a2, self.a3, self.a4], + ) + self.assertSequenceEqual( + Article.objects.order_by(Upper("author__name").desc(nulls_first=True)), + [self.a1, self.a2, self.a4, self.a3], + ) + + def test_stop_slicing(self): + """ + Use the 'stop' part of slicing notation to limit the results. + """ + self.assertQuerysetEqual( + Article.objects.order_by("headline")[:2], [ + "Article 1", + "Article 2", + ], + attrgetter("headline") + ) + + def test_stop_start_slicing(self): + """ + Use the 'stop' and 'start' parts of slicing notation to offset the + result list. 
+ """ + self.assertQuerysetEqual( + Article.objects.order_by("headline")[1:3], [ + "Article 2", + "Article 3", + ], + attrgetter("headline") + ) + + def test_random_ordering(self): + """ + Use '?' to order randomly. + """ + self.assertEqual( + len(list(Article.objects.order_by("?"))), 4 + ) + + def test_reversed_ordering(self): + """ + Ordering can be reversed using the reverse() method on a queryset. + This allows you to extract things like "the last two items" (reverse + and then take the first two). + """ + self.assertQuerysetEqual( + Article.objects.all().reverse()[:2], [ + "Article 1", + "Article 3", + ], + attrgetter("headline") + ) + + def test_reverse_ordering_pure(self): + qs1 = Article.objects.order_by(F('headline').asc()) + qs2 = qs1.reverse() + self.assertQuerysetEqual( + qs1, [ + "Article 1", + "Article 2", + "Article 3", + "Article 4", + ], + attrgetter("headline") + ) + self.assertQuerysetEqual( + qs2, [ + "Article 4", + "Article 3", + "Article 2", + "Article 1", + ], + attrgetter("headline") + ) + + def test_extra_ordering(self): + """ + Ordering can be based on fields included from an 'extra' clause + """ + self.assertQuerysetEqual( + Article.objects.extra(select={"foo": "pub_date"}, order_by=["foo", "headline"]), [ + "Article 1", + "Article 2", + "Article 3", + "Article 4", + ], + attrgetter("headline") + ) + + def test_extra_ordering_quoting(self): + """ + If the extra clause uses an SQL keyword for a name, it will be + protected by quoting. 
+ """ + self.assertQuerysetEqual( + Article.objects.extra(select={"order": "pub_date"}, order_by=["order", "headline"]), [ + "Article 1", + "Article 2", + "Article 3", + "Article 4", + ], + attrgetter("headline") + ) + + def test_extra_ordering_with_table_name(self): + self.assertQuerysetEqual( + Article.objects.extra(order_by=['ordering_article.headline']), [ + "Article 1", + "Article 2", + "Article 3", + "Article 4", + ], + attrgetter("headline") + ) + self.assertQuerysetEqual( + Article.objects.extra(order_by=['-ordering_article.headline']), [ + "Article 4", + "Article 3", + "Article 2", + "Article 1", + ], + attrgetter("headline") + ) + + def test_order_by_pk(self): + """ + 'pk' works as an ordering option in Meta. + """ + self.assertQuerysetEqual( + Author.objects.all(), + list(reversed(range(1, Author.objects.count() + 1))), + attrgetter("pk"), + ) + + def test_order_by_fk_attname(self): + """ + ordering by a foreign key by its attribute name prevents the query + from inheriting its related model ordering option (#19195). 
+ """ + for i in range(1, 5): + author = Author.objects.get(pk=i) + article = getattr(self, "a%d" % (5 - i)) + article.author = author + article.save(update_fields={'author'}) + + self.assertQuerysetEqual( + Article.objects.order_by('author_id'), [ + "Article 4", + "Article 3", + "Article 2", + "Article 1", + ], + attrgetter("headline") + ) + + def test_order_by_f_expression(self): + self.assertQuerysetEqual( + Article.objects.order_by(F('headline')), [ + "Article 1", + "Article 2", + "Article 3", + "Article 4", + ], + attrgetter("headline") + ) + self.assertQuerysetEqual( + Article.objects.order_by(F('headline').asc()), [ + "Article 1", + "Article 2", + "Article 3", + "Article 4", + ], + attrgetter("headline") + ) + self.assertQuerysetEqual( + Article.objects.order_by(F('headline').desc()), [ + "Article 4", + "Article 3", + "Article 2", + "Article 1", + ], + attrgetter("headline") + ) + + def test_order_by_f_expression_duplicates(self): + """ + A column may only be included once (the first occurrence) so we check + to ensure there are no duplicates by inspecting the SQL. + """ + qs = Article.objects.order_by(F('headline').asc(), F('headline').desc()) + sql = str(qs.query).upper() + fragment = sql[sql.find('ORDER BY'):] + self.assertEqual(fragment.count('HEADLINE'), 1) + self.assertQuerysetEqual( + qs, [ + "Article 1", + "Article 2", + "Article 3", + "Article 4", + ], + attrgetter("headline") + ) + qs = Article.objects.order_by(F('headline').desc(), F('headline').asc()) + sql = str(qs.query).upper() + fragment = sql[sql.find('ORDER BY'):] + self.assertEqual(fragment.count('HEADLINE'), 1) + self.assertQuerysetEqual( + qs, [ + "Article 4", + "Article 3", + "Article 2", + "Article 1", + ], + attrgetter("headline") + ) + + def test_related_ordering_duplicate_table_reference(self): + """ + An ordering referencing a model with an ordering referencing a model + multiple time no circular reference should be detected (#24654). 
+ """ + first_author = Author.objects.create() + second_author = Author.objects.create() + self.a1.author = first_author + self.a1.second_author = second_author + self.a1.save() + self.a2.author = second_author + self.a2.second_author = first_author + self.a2.save() + r1 = Reference.objects.create(article_id=self.a1.pk) + r2 = Reference.objects.create(article_id=self.a2.pk) + self.assertSequenceEqual(Reference.objects.all(), [r2, r1]) diff --git a/tests/pagination/__init__.py b/tests/pagination/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/pagination/custom.py b/tests/pagination/custom.py new file mode 100644 index 00000000..77277dca --- /dev/null +++ b/tests/pagination/custom.py @@ -0,0 +1,20 @@ +from django.core.paginator import Page, Paginator + + +class ValidAdjacentNumsPage(Page): + + def next_page_number(self): + if not self.has_next(): + return None + return super(ValidAdjacentNumsPage, self).next_page_number() + + def previous_page_number(self): + if not self.has_previous(): + return None + return super(ValidAdjacentNumsPage, self).previous_page_number() + + +class ValidAdjacentNumsPaginator(Paginator): + + def _get_page(self, *args, **kwargs): + return ValidAdjacentNumsPage(*args, **kwargs) diff --git a/tests/pagination/models.py b/tests/pagination/models.py new file mode 100644 index 00000000..9dc8d4b7 --- /dev/null +++ b/tests/pagination/models.py @@ -0,0 +1,11 @@ +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Article(models.Model): + headline = models.CharField(max_length=100, default='Default headline') + pub_date = models.DateTimeField() + + def __str__(self): + return self.headline diff --git a/tests/pagination/tests.py b/tests/pagination/tests.py new file mode 100644 index 00000000..2572dbe6 --- /dev/null +++ b/tests/pagination/tests.py @@ -0,0 +1,362 @@ +from __future__ import unicode_literals + +import unittest +import warnings 
+from datetime import datetime + +import django +from django.core.paginator import ( + EmptyPage, InvalidPage, PageNotAnInteger, Paginator, +) +if django.VERSION >= (1, 11, 0): + from django.core.paginator import UnorderedObjectListWarning + +from django.test import TestCase +from django.utils import six + +from .custom import ValidAdjacentNumsPaginator +from .models import Article + + +class PaginationTests(unittest.TestCase): + """ + Tests for the Paginator and Page classes. + """ + + def check_paginator(self, params, output): + """ + Helper method that instantiates a Paginator object from the passed + params and then checks that its attributes match the passed output. + """ + count, num_pages, page_range = output + paginator = Paginator(*params) + self.check_attribute('count', paginator, count, params) + self.check_attribute('num_pages', paginator, num_pages, params) + self.check_attribute('page_range', paginator, page_range, params, coerce=list) + + def check_attribute(self, name, paginator, expected, params, coerce=None): + """ + Helper method that checks a single attribute and gives a nice error + message upon test failure. + """ + got = getattr(paginator, name) + if coerce is not None: + got = coerce(got) + self.assertEqual( + expected, got, + "For '%s', expected %s but got %s. Paginator parameters were: %s" + % (name, expected, got, params) + ) + + def test_paginator(self): + """ + Tests the paginator attributes using varying inputs. + """ + nine = [1, 2, 3, 4, 5, 6, 7, 8, 9] + ten = nine + [10] + eleven = ten + [11] + tests = ( + # Each item is two tuples: + # First tuple is Paginator parameters - object_list, per_page, + # orphans, and allow_empty_first_page. + # Second tuple is resulting Paginator attributes - count, + # num_pages, and page_range. + # Ten items, varying orphans, no empty first page. 
+ ((ten, 4, 0, False), (10, 3, [1, 2, 3])), + ((ten, 4, 1, False), (10, 3, [1, 2, 3])), + ((ten, 4, 2, False), (10, 2, [1, 2])), + ((ten, 4, 5, False), (10, 2, [1, 2])), + ((ten, 4, 6, False), (10, 1, [1])), + # Ten items, varying orphans, allow empty first page. + ((ten, 4, 0, True), (10, 3, [1, 2, 3])), + ((ten, 4, 1, True), (10, 3, [1, 2, 3])), + ((ten, 4, 2, True), (10, 2, [1, 2])), + ((ten, 4, 5, True), (10, 2, [1, 2])), + ((ten, 4, 6, True), (10, 1, [1])), + # One item, varying orphans, no empty first page. + (([1], 4, 0, False), (1, 1, [1])), + (([1], 4, 1, False), (1, 1, [1])), + (([1], 4, 2, False), (1, 1, [1])), + # One item, varying orphans, allow empty first page. + (([1], 4, 0, True), (1, 1, [1])), + (([1], 4, 1, True), (1, 1, [1])), + (([1], 4, 2, True), (1, 1, [1])), + # Zero items, varying orphans, no empty first page. + (([], 4, 0, False), (0, 0, [])), + (([], 4, 1, False), (0, 0, [])), + (([], 4, 2, False), (0, 0, [])), + # Zero items, varying orphans, allow empty first page. + (([], 4, 0, True), (0, 1, [1])), + (([], 4, 1, True), (0, 1, [1])), + (([], 4, 2, True), (0, 1, [1])), + # Number if items one less than per_page. + (([], 1, 0, True), (0, 1, [1])), + (([], 1, 0, False), (0, 0, [])), + (([1], 2, 0, True), (1, 1, [1])), + ((nine, 10, 0, True), (9, 1, [1])), + # Number if items equal to per_page. + (([1], 1, 0, True), (1, 1, [1])), + (([1, 2], 2, 0, True), (2, 1, [1])), + ((ten, 10, 0, True), (10, 1, [1])), + # Number if items one more than per_page. + (([1, 2], 1, 0, True), (2, 2, [1, 2])), + (([1, 2, 3], 2, 0, True), (3, 2, [1, 2])), + ((eleven, 10, 0, True), (11, 2, [1, 2])), + # Number if items one more than per_page with one orphan. 
+ (([1, 2], 1, 1, True), (2, 1, [1])), + (([1, 2, 3], 2, 1, True), (3, 1, [1])), + ((eleven, 10, 1, True), (11, 1, [1])), + # Non-integer inputs + ((ten, '4', 1, False), (10, 3, [1, 2, 3])), + ((ten, '4', 1, False), (10, 3, [1, 2, 3])), + ((ten, 4, '1', False), (10, 3, [1, 2, 3])), + ((ten, 4, '1', False), (10, 3, [1, 2, 3])), + ) + for params, output in tests: + self.check_paginator(params, output) + + def test_invalid_page_number(self): + """ + Invalid page numbers result in the correct exception being raised. + """ + paginator = Paginator([1, 2, 3], 2) + with self.assertRaises(InvalidPage): + paginator.page(3) + with self.assertRaises(PageNotAnInteger): + paginator.validate_number(None) + with self.assertRaises(PageNotAnInteger): + paginator.validate_number('x') + # With no content and allow_empty_first_page=True, 1 is a valid page number + paginator = Paginator([], 2) + self.assertEqual(paginator.validate_number(1), 1) + + def test_paginate_misc_classes(self): + class CountContainer(object): + def count(self): + return 42 + # Paginator can be passed other objects with a count() method. + paginator = Paginator(CountContainer(), 10) + self.assertEqual(42, paginator.count) + self.assertEqual(5, paginator.num_pages) + self.assertEqual([1, 2, 3, 4, 5], list(paginator.page_range)) + + # Paginator can be passed other objects that implement __len__. + class LenContainer(object): + def __len__(self): + return 42 + paginator = Paginator(LenContainer(), 10) + self.assertEqual(42, paginator.count) + self.assertEqual(5, paginator.num_pages) + self.assertEqual([1, 2, 3, 4, 5], list(paginator.page_range)) + + def check_indexes(self, params, page_num, indexes): + """ + Helper method that instantiates a Paginator object from the passed + params and then checks that the start and end indexes of the passed + page_num match those given as a 2-tuple in indexes. 
+ """ + paginator = Paginator(*params) + if page_num == 'first': + page_num = 1 + elif page_num == 'last': + page_num = paginator.num_pages + page = paginator.page(page_num) + start, end = indexes + msg = ("For %s of page %s, expected %s but got %s. Paginator parameters were: %s") + self.assertEqual(start, page.start_index(), msg % ('start index', page_num, start, page.start_index(), params)) + self.assertEqual(end, page.end_index(), msg % ('end index', page_num, end, page.end_index(), params)) + + def test_page_indexes(self): + """ + Paginator pages have the correct start and end indexes. + """ + ten = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + tests = ( + # Each item is three tuples: + # First tuple is Paginator parameters - object_list, per_page, + # orphans, and allow_empty_first_page. + # Second tuple is the start and end indexes of the first page. + # Third tuple is the start and end indexes of the last page. + # Ten items, varying per_page, no orphans. + ((ten, 1, 0, True), (1, 1), (10, 10)), + ((ten, 2, 0, True), (1, 2), (9, 10)), + ((ten, 3, 0, True), (1, 3), (10, 10)), + ((ten, 5, 0, True), (1, 5), (6, 10)), + # Ten items, varying per_page, with orphans. + ((ten, 1, 1, True), (1, 1), (9, 10)), + ((ten, 1, 2, True), (1, 1), (8, 10)), + ((ten, 3, 1, True), (1, 3), (7, 10)), + ((ten, 3, 2, True), (1, 3), (7, 10)), + ((ten, 3, 4, True), (1, 3), (4, 10)), + ((ten, 5, 1, True), (1, 5), (6, 10)), + ((ten, 5, 2, True), (1, 5), (6, 10)), + ((ten, 5, 5, True), (1, 10), (1, 10)), + # One item, varying orphans, no empty first page. + (([1], 4, 0, False), (1, 1), (1, 1)), + (([1], 4, 1, False), (1, 1), (1, 1)), + (([1], 4, 2, False), (1, 1), (1, 1)), + # One item, varying orphans, allow empty first page. + (([1], 4, 0, True), (1, 1), (1, 1)), + (([1], 4, 1, True), (1, 1), (1, 1)), + (([1], 4, 2, True), (1, 1), (1, 1)), + # Zero items, varying orphans, allow empty first page. 
+ (([], 4, 0, True), (0, 0), (0, 0)), + (([], 4, 1, True), (0, 0), (0, 0)), + (([], 4, 2, True), (0, 0), (0, 0)), + ) + for params, first, last in tests: + self.check_indexes(params, 'first', first) + self.check_indexes(params, 'last', last) + + # When no items and no empty first page, we should get EmptyPage error. + with self.assertRaises(EmptyPage): + self.check_indexes(([], 4, 0, False), 1, None) + with self.assertRaises(EmptyPage): + self.check_indexes(([], 4, 1, False), 1, None) + with self.assertRaises(EmptyPage): + self.check_indexes(([], 4, 2, False), 1, None) + + def test_page_sequence(self): + """ + A paginator page acts like a standard sequence. + """ + eleven = 'abcdefghijk' + page2 = Paginator(eleven, per_page=5, orphans=1).page(2) + self.assertEqual(len(page2), 6) + self.assertIn('k', page2) + self.assertNotIn('a', page2) + self.assertEqual(''.join(page2), 'fghijk') + self.assertEqual(''.join(reversed(page2)), 'kjihgf') + + def test_get_page_hook(self): + """ + A Paginator subclass can use the ``_get_page`` hook to + return an alternative to the standard Page class. + """ + eleven = 'abcdefghijk' + paginator = ValidAdjacentNumsPaginator(eleven, per_page=6) + page1 = paginator.page(1) + page2 = paginator.page(2) + self.assertIsNone(page1.previous_page_number()) + self.assertEqual(page1.next_page_number(), 2) + self.assertEqual(page2.previous_page_number(), 1) + self.assertIsNone(page2.next_page_number()) + + def test_page_range_iterator(self): + """ + Paginator.page_range should be an iterator. + """ + self.assertIsInstance(Paginator([1, 2, 3], 2).page_range, type(six.moves.range(0))) + + +class ModelPaginationTests(TestCase): + """ + Test pagination with Django model instances + """ + def setUp(self): + # Prepare a list of objects for pagination. 
+        for x in range(1, 10):
+            a = Article(headline='Article %s' % x, pub_date=datetime(2005, 7, 29))
+            a.save()
+
+    def test_first_page(self):
+        paginator = Paginator(Article.objects.order_by('id'), 5)
+        p = paginator.page(1)
+        self.assertEqual("<Page 1 of 2>", six.text_type(p))
+        self.assertQuerysetEqual(p.object_list, [
+            "<Article: Article 1>",
+            "<Article: Article 2>",
+            "<Article: Article 3>",
+            "<Article: Article 4>",
+            "<Article: Article 5>"
+        ])
+        self.assertTrue(p.has_next())
+        self.assertFalse(p.has_previous())
+        self.assertTrue(p.has_other_pages())
+        self.assertEqual(2, p.next_page_number())
+        with self.assertRaises(InvalidPage):
+            p.previous_page_number()
+        self.assertEqual(1, p.start_index())
+        self.assertEqual(5, p.end_index())
+
+    def test_last_page(self):
+        paginator = Paginator(Article.objects.order_by('id'), 5)
+        p = paginator.page(2)
+        self.assertEqual("<Page 2 of 2>", six.text_type(p))
+        self.assertQuerysetEqual(p.object_list, [
+            "<Article: Article 6>",
+            "<Article: Article 7>",
+            "<Article: Article 8>",
+            "<Article: Article 9>"
+        ])
+        self.assertFalse(p.has_next())
+        self.assertTrue(p.has_previous())
+        self.assertTrue(p.has_other_pages())
+        with self.assertRaises(InvalidPage):
+            p.next_page_number()
+        self.assertEqual(1, p.previous_page_number())
+        self.assertEqual(6, p.start_index())
+        self.assertEqual(9, p.end_index())
+
+    def test_page_getitem(self):
+        """
+        Tests proper behavior of a paginator page __getitem__ (queryset
+        evaluation, slicing, exception raised).
+        """
+        paginator = Paginator(Article.objects.order_by('id'), 5)
+        p = paginator.page(1)
+
+        # Make sure object_list queryset is not evaluated by an invalid __getitem__ call.
+        # (this happens from the template engine when using eg: {% page_obj.has_previous %})
+        self.assertIsNone(p.object_list._result_cache)
+        with self.assertRaises(TypeError):
+            p['has_previous']
+        self.assertIsNone(p.object_list._result_cache)
+        self.assertNotIsInstance(p.object_list, list)
+
+        # Make sure slicing the Page object with numbers and slice objects work.
+        self.assertEqual(p[0], Article.objects.get(headline='Article 1'))
+        self.assertQuerysetEqual(p[slice(2)], [
+            "<Article: Article 1>",
+            "<Article: Article 2>",
+            ]
+        )
+        # After __getitem__ is called, object_list is a list
+        self.assertIsInstance(p.object_list, list)
+
+    def test_paginating_unordered_queryset_raises_warning(self):
+        if django.VERSION < (1, 11, 0):
+            self.skipTest("does not work on older version of Django")
+        with warnings.catch_warnings(record=True) as warns:
+            # Prevent the RuntimeWarning subclass from appearing as an
+            # exception due to the warnings.simplefilter() in runtests.py.
+            warnings.filterwarnings('always', category=UnorderedObjectListWarning)
+            Paginator(Article.objects.all(), 5)
+        self.assertEqual(len(warns), 1)
+        warning = warns[0]
+        self.assertEqual(str(warning.message), (
+            "Pagination may yield inconsistent results with an unordered "
+            "object_list: <class 'pagination.models.Article'> QuerySet."
+        ))
+        # The warning points at the Paginator caller (i.e. the stacklevel
+        # is appropriate).
+        self.assertEqual(warning.filename, __file__)
+
+    def test_paginating_unordered_object_list_raises_warning(self):
+        """
+        Unordered object list warning with an object that has an ordered
+        attribute but not a model attribute.
+ """ + if django.VERSION < (1, 11, 0): + self.skipTest("does not work on older version of Django") + class ObjectList(): + ordered = False + object_list = ObjectList() + with warnings.catch_warnings(record=True) as warns: + warnings.filterwarnings('always', category=UnorderedObjectListWarning) + Paginator(object_list, 5) + self.assertEqual(len(warns), 1) + self.assertEqual(str(warns[0].message), ( + "Pagination may yield inconsistent results with an unordered " + "object_list: {!r}.".format(object_list) + )) diff --git a/tests/queries/__init__.py b/tests/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/queries/models.py b/tests/queries/models.py new file mode 100644 index 00000000..587d2e68 --- /dev/null +++ b/tests/queries/models.py @@ -0,0 +1,720 @@ +""" +Various complex queries that have been problematic in the past. +""" +import threading + +from django.db import models + + +class DumbCategory(models.Model): + pass + + +class ProxyCategory(DumbCategory): + class Meta: + proxy = True + + +class NamedCategory(DumbCategory): + name = models.CharField(max_length=10) + + def __str__(self): + return self.name + + +class Tag(models.Model): + name = models.CharField(max_length=10) + parent = models.ForeignKey( + 'self', + models.SET_NULL, + blank=True, null=True, + related_name='children', + ) + category = models.ForeignKey(NamedCategory, models.SET_NULL, null=True, default=None) + + class Meta: + ordering = ['name'] + + def __str__(self): + return self.name + + +class Note(models.Model): + note = models.CharField(max_length=100) + misc = models.CharField(max_length=10) + tag = models.ForeignKey(Tag, models.SET_NULL, blank=True, null=True) + + class Meta: + ordering = ['note'] + + def __str__(self): + return self.note + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Regression for #13227 -- having an attribute that + # is unpicklable doesn't stop you from cloning queries + # that use objects of that 
type as an argument. + self.lock = threading.Lock() + + +class Annotation(models.Model): + name = models.CharField(max_length=10) + tag = models.ForeignKey(Tag, models.CASCADE) + notes = models.ManyToManyField(Note) + + def __str__(self): + return self.name + + +class ExtraInfo(models.Model): + info = models.CharField(max_length=100) + note = models.ForeignKey(Note, models.CASCADE, null=True) + value = models.IntegerField(null=True) + + class Meta: + ordering = ['info'] + + def __str__(self): + return self.info + + +class Author(models.Model): + name = models.CharField(max_length=10) + num = models.IntegerField(unique=True) + extra = models.ForeignKey(ExtraInfo, models.CASCADE) + + class Meta: + ordering = ['name'] + + def __str__(self): + return self.name + + +class Item(models.Model): + name = models.CharField(max_length=10) + created = models.DateTimeField() + modified = models.DateTimeField(blank=True, null=True) + tags = models.ManyToManyField(Tag, blank=True) + creator = models.ForeignKey(Author, models.CASCADE) + note = models.ForeignKey(Note, models.CASCADE) + + class Meta: + ordering = ['-note', 'name'] + + def __str__(self): + return self.name + + +class Report(models.Model): + name = models.CharField(max_length=10) + creator = models.ForeignKey(Author, models.SET_NULL, to_field='num', null=True) + + def __str__(self): + return self.name + + +class ReportComment(models.Model): + report = models.ForeignKey(Report, models.CASCADE) + + +class Ranking(models.Model): + rank = models.IntegerField() + author = models.ForeignKey(Author, models.CASCADE) + + class Meta: + # A complex ordering specification. Should stress the system a bit. 
+ ordering = ('author__extra__note', 'author__name', 'rank') + + def __str__(self): + return '%d: %s' % (self.rank, self.author.name) + + +class Cover(models.Model): + title = models.CharField(max_length=50) + item = models.ForeignKey(Item, models.CASCADE) + + class Meta: + ordering = ['item'] + + def __str__(self): + return self.title + + +class Number(models.Model): + num = models.IntegerField() + + def __str__(self): + return str(self.num) + +# Symmetrical m2m field with a normal field using the reverse accessor name +# ("valid"). + + +class Valid(models.Model): + valid = models.CharField(max_length=10) + parent = models.ManyToManyField('self') + + class Meta: + ordering = ['valid'] + +# Some funky cross-linked models for testing a couple of infinite recursion +# cases. + + +class X(models.Model): + y = models.ForeignKey('Y', models.CASCADE) + + +class Y(models.Model): + x1 = models.ForeignKey(X, models.CASCADE, related_name='y1') + +# Some models with a cycle in the default ordering. This would be bad if we +# didn't catch the infinite loop. + + +class LoopX(models.Model): + y = models.ForeignKey('LoopY', models.CASCADE) + + class Meta: + ordering = ['y'] + + +class LoopY(models.Model): + x = models.ForeignKey(LoopX, models.CASCADE) + + class Meta: + ordering = ['x'] + + +class LoopZ(models.Model): + z = models.ForeignKey('self', models.CASCADE) + + class Meta: + ordering = ['z'] + + +# A model and custom default manager combination. + + +class CustomManager(models.Manager): + def get_queryset(self): + qs = super().get_queryset() + return qs.filter(public=True, tag__name='t1') + + +class ManagedModel(models.Model): + data = models.CharField(max_length=10) + tag = models.ForeignKey(Tag, models.CASCADE) + public = models.BooleanField(default=True) + + objects = CustomManager() + normal_manager = models.Manager() + + def __str__(self): + return self.data + +# An inter-related setup with multiple paths from Child to Detail. 
+ + +class Detail(models.Model): + data = models.CharField(max_length=10) + + +class MemberManager(models.Manager): + def get_queryset(self): + return super().get_queryset().select_related("details") + + +class Member(models.Model): + name = models.CharField(max_length=10) + details = models.OneToOneField(Detail, models.CASCADE, primary_key=True) + + objects = MemberManager() + + +class Child(models.Model): + person = models.OneToOneField(Member, models.CASCADE, primary_key=True) + parent = models.ForeignKey(Member, models.CASCADE, related_name="children") + +# Custom primary keys interfered with ordering in the past. + + +class CustomPk(models.Model): + name = models.CharField(max_length=10, primary_key=True) + extra = models.CharField(max_length=10) + + class Meta: + ordering = ['name', 'extra'] + + +class Related(models.Model): + custom = models.ForeignKey(CustomPk, models.CASCADE, null=True) + + +class CustomPkTag(models.Model): + id = models.CharField(max_length=20, primary_key=True) + custom_pk = models.ManyToManyField(CustomPk) + tag = models.CharField(max_length=20) + +# An inter-related setup with a model subclass that has a nullable +# path to another model, and a return path from that model. 
+ + +class Celebrity(models.Model): + name = models.CharField("Name", max_length=20) + greatest_fan = models.ForeignKey("Fan", models.SET_NULL, null=True, unique=True) + + def __str__(self): + return self.name + + +class TvChef(Celebrity): + pass + + +class Fan(models.Model): + fan_of = models.ForeignKey(Celebrity, models.CASCADE) + +# Multiple foreign keys + + +class LeafA(models.Model): + data = models.CharField(max_length=10) + + def __str__(self): + return self.data + + +class LeafB(models.Model): + data = models.CharField(max_length=10) + + +class Join(models.Model): + a = models.ForeignKey(LeafA, models.CASCADE) + b = models.ForeignKey(LeafB, models.CASCADE) + + +class ReservedName(models.Model): + name = models.CharField(max_length=20) + order = models.IntegerField() + + def __str__(self): + return self.name + +# A simpler shared-foreign-key setup that can expose some problems. + + +class SharedConnection(models.Model): + data = models.CharField(max_length=10) + + def __str__(self): + return self.data + + +class PointerA(models.Model): + connection = models.ForeignKey(SharedConnection, models.CASCADE) + + +class PointerB(models.Model): + connection = models.ForeignKey(SharedConnection, models.CASCADE) + +# Multi-layer ordering + + +class SingleObject(models.Model): + name = models.CharField(max_length=10) + + class Meta: + ordering = ['name'] + + def __str__(self): + return self.name + + +class RelatedObject(models.Model): + single = models.ForeignKey(SingleObject, models.SET_NULL, null=True) + f = models.IntegerField(null=True) + + class Meta: + ordering = ['single'] + + +class Plaything(models.Model): + name = models.CharField(max_length=10) + others = models.ForeignKey(RelatedObject, models.SET_NULL, null=True) + + class Meta: + ordering = ['others'] + + def __str__(self): + return self.name + + +class Article(models.Model): + name = models.CharField(max_length=20) + created = models.DateTimeField() + + def __str__(self): + return self.name + + +class 
Food(models.Model): + name = models.CharField(max_length=20, unique=True) + + def __str__(self): + return self.name + + +class Eaten(models.Model): + food = models.ForeignKey(Food, models.SET_NULL, to_field="name", null=True) + meal = models.CharField(max_length=20) + + def __str__(self): + return "%s at %s" % (self.food, self.meal) + + +class Node(models.Model): + num = models.IntegerField(unique=True) + parent = models.ForeignKey("self", models.SET_NULL, to_field="num", null=True) + + def __str__(self): + return "%s" % self.num + +# Bug #12252 + + +class ObjectA(models.Model): + name = models.CharField(max_length=50) + + def __str__(self): + return self.name + + def __iter__(self): + # Ticket #23721 + assert False, 'type checking should happen without calling model __iter__' + + +class ProxyObjectA(ObjectA): + class Meta: + proxy = True + + +class ChildObjectA(ObjectA): + pass + + +class ObjectB(models.Model): + name = models.CharField(max_length=50) + objecta = models.ForeignKey(ObjectA, models.CASCADE) + num = models.PositiveSmallIntegerField() + + def __str__(self): + return self.name + + +class ProxyObjectB(ObjectB): + class Meta: + proxy = True + + +class ObjectC(models.Model): + name = models.CharField(max_length=50) + objecta = models.ForeignKey(ObjectA, models.SET_NULL, null=True) + objectb = models.ForeignKey(ObjectB, models.SET_NULL, null=True) + childobjecta = models.ForeignKey(ChildObjectA, models.SET_NULL, null=True, related_name='ca_pk') + + def __str__(self): + return self.name + + +class SimpleCategory(models.Model): + name = models.CharField(max_length=15) + + def __str__(self): + return self.name + + +class SpecialCategory(SimpleCategory): + special_name = models.CharField(max_length=15) + + def __str__(self): + return self.name + " " + self.special_name + + +class CategoryItem(models.Model): + category = models.ForeignKey(SimpleCategory, models.CASCADE) + + def __str__(self): + return "category item: " + str(self.category) + + +class 
MixedCaseFieldCategoryItem(models.Model): + CaTeGoRy = models.ForeignKey(SimpleCategory, models.CASCADE) + + +class MixedCaseDbColumnCategoryItem(models.Model): + category = models.ForeignKey(SimpleCategory, models.CASCADE, db_column='CaTeGoRy_Id') + + +class OneToOneCategory(models.Model): + new_name = models.CharField(max_length=15) + category = models.OneToOneField(SimpleCategory, models.CASCADE) + + def __str__(self): + return "one2one " + self.new_name + + +class CategoryRelationship(models.Model): + first = models.ForeignKey(SimpleCategory, models.CASCADE, related_name='first_rel') + second = models.ForeignKey(SimpleCategory, models.CASCADE, related_name='second_rel') + + +class CommonMixedCaseForeignKeys(models.Model): + category = models.ForeignKey(CategoryItem, models.CASCADE) + mixed_case_field_category = models.ForeignKey(MixedCaseFieldCategoryItem, models.CASCADE) + mixed_case_db_column_category = models.ForeignKey(MixedCaseDbColumnCategoryItem, models.CASCADE) + + +class NullableName(models.Model): + name = models.CharField(max_length=20, null=True) + + class Meta: + ordering = ['id'] + + +class ModelD(models.Model): + name = models.TextField() + + +class ModelC(models.Model): + name = models.TextField() + + +class ModelB(models.Model): + name = models.TextField() + c = models.ForeignKey(ModelC, models.CASCADE) + + +class ModelA(models.Model): + name = models.TextField() + b = models.ForeignKey(ModelB, models.SET_NULL, null=True) + d = models.ForeignKey(ModelD, models.CASCADE) + + +class Job(models.Model): + name = models.CharField(max_length=20, unique=True) + + def __str__(self): + return self.name + + +class JobResponsibilities(models.Model): + job = models.ForeignKey(Job, models.CASCADE, to_field='name') + responsibility = models.ForeignKey('Responsibility', models.CASCADE, to_field='description') + + +class Responsibility(models.Model): + description = models.CharField(max_length=20, unique=True) + jobs = models.ManyToManyField(Job, 
through=JobResponsibilities, + related_name='responsibilities') + + def __str__(self): + return self.description + +# Models for disjunction join promotion low level testing. + + +class FK1(models.Model): + f1 = models.TextField() + f2 = models.TextField() + + +class FK2(models.Model): + f1 = models.TextField() + f2 = models.TextField() + + +class FK3(models.Model): + f1 = models.TextField() + f2 = models.TextField() + + +class BaseA(models.Model): + a = models.ForeignKey(FK1, models.SET_NULL, null=True) + b = models.ForeignKey(FK2, models.SET_NULL, null=True) + c = models.ForeignKey(FK3, models.SET_NULL, null=True) + + +class Identifier(models.Model): + name = models.CharField(max_length=100) + + def __str__(self): + return self.name + + +class Program(models.Model): + identifier = models.OneToOneField(Identifier, models.CASCADE) + + +class Channel(models.Model): + programs = models.ManyToManyField(Program) + identifier = models.OneToOneField(Identifier, models.CASCADE) + + +class Book(models.Model): + title = models.TextField() + chapter = models.ForeignKey('Chapter', models.CASCADE) + + +class Chapter(models.Model): + title = models.TextField() + paragraph = models.ForeignKey('Paragraph', models.CASCADE) + + +class Paragraph(models.Model): + text = models.TextField() + page = models.ManyToManyField('Page') + + +class Page(models.Model): + text = models.TextField() + + +class MyObject(models.Model): + parent = models.ForeignKey('self', models.SET_NULL, null=True, blank=True, related_name='children') + data = models.CharField(max_length=100) + created_at = models.DateTimeField(auto_now_add=True) + +# Models for #17600 regressions + + +class Order(models.Model): + id = models.IntegerField(primary_key=True) + + class Meta: + ordering = ('pk',) + + def __str__(self): + return '%s' % self.pk + + +class OrderItem(models.Model): + order = models.ForeignKey(Order, models.CASCADE, related_name='items') + status = models.IntegerField() + + class Meta: + ordering = ('pk',) 
+ + def __str__(self): + return '%s' % self.pk + + +class BaseUser(models.Model): + pass + + +class Task(models.Model): + title = models.CharField(max_length=10) + owner = models.ForeignKey(BaseUser, models.CASCADE, related_name='owner') + creator = models.ForeignKey(BaseUser, models.CASCADE, related_name='creator') + + def __str__(self): + return self.title + + +class Staff(models.Model): + name = models.CharField(max_length=10) + + def __str__(self): + return self.name + + +class StaffUser(BaseUser): + staff = models.OneToOneField(Staff, models.CASCADE, related_name='user') + + def __str__(self): + return self.staff + + +class Ticket21203Parent(models.Model): + parentid = models.AutoField(primary_key=True) + parent_bool = models.BooleanField(default=True) + created = models.DateTimeField(auto_now=True) + + +class Ticket21203Child(models.Model): + childid = models.AutoField(primary_key=True) + parent = models.ForeignKey(Ticket21203Parent, models.CASCADE) + + +class Person(models.Model): + name = models.CharField(max_length=128) + + +class Company(models.Model): + name = models.CharField(max_length=128) + employees = models.ManyToManyField(Person, related_name='employers', through='Employment') + + def __str__(self): + return self.name + + +class Employment(models.Model): + employer = models.ForeignKey(Company, models.CASCADE) + employee = models.ForeignKey(Person, models.CASCADE) + title = models.CharField(max_length=128) + + +class School(models.Model): + pass + + +class Student(models.Model): + school = models.ForeignKey(School, models.CASCADE) + + +class Classroom(models.Model): + name = models.CharField(max_length=20) + has_blackboard = models.BooleanField(null=True) + school = models.ForeignKey(School, models.CASCADE) + students = models.ManyToManyField(Student, related_name='classroom') + + +class Teacher(models.Model): + schools = models.ManyToManyField(School) + friends = models.ManyToManyField('self') + + +class Ticket23605AParent(models.Model): + pass + 
+ +class Ticket23605A(Ticket23605AParent): + pass + + +class Ticket23605B(models.Model): + modela_fk = models.ForeignKey(Ticket23605A, models.CASCADE) + modelc_fk = models.ForeignKey("Ticket23605C", models.CASCADE) + field_b0 = models.IntegerField(null=True) + field_b1 = models.BooleanField(default=False) + + +class Ticket23605C(models.Model): + field_c0 = models.FloatField() + + +# db_table names have capital letters to ensure they are quoted in queries. +class Individual(models.Model): + alive = models.BooleanField() + + class Meta: + db_table = 'Individual' + + +class RelatedIndividual(models.Model): + related = models.ForeignKey(Individual, models.CASCADE, related_name='related_individual') + + class Meta: + db_table = 'RelatedIndividual' diff --git a/tests/queries/test_explain.py b/tests/queries/test_explain.py new file mode 100644 index 00000000..ad4ca988 --- /dev/null +++ b/tests/queries/test_explain.py @@ -0,0 +1,102 @@ +import unittest + +from django.db import NotSupportedError, connection, transaction +from django.db.models import Count +from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature +from django.test.utils import CaptureQueriesContext + +from .models import Tag + + +@skipUnlessDBFeature('supports_explaining_query_execution') +class ExplainTests(TestCase): + + def test_basic(self): + querysets = [ + Tag.objects.filter(name='test'), + Tag.objects.filter(name='test').select_related('parent'), + Tag.objects.filter(name='test').prefetch_related('children'), + Tag.objects.filter(name='test').annotate(Count('children')), + Tag.objects.filter(name='test').values_list('name'), + Tag.objects.order_by().union(Tag.objects.order_by().filter(name='test')), + Tag.objects.all().select_for_update().filter(name='test'), + ] + supported_formats = connection.features.supported_explain_formats + all_formats = (None,) + tuple(supported_formats) + tuple(f.lower() for f in supported_formats) + for idx, queryset in enumerate(querysets): + for format in 
all_formats: + with self.subTest(format=format, queryset=idx): + if connection.vendor == 'mysql': + # This does a query and caches the result. + connection.features.needs_explain_extended + with self.assertNumQueries(1), CaptureQueriesContext(connection) as captured_queries: + result = queryset.explain(format=format) + self.assertTrue(captured_queries[0]['sql'].startswith(connection.ops.explain_prefix)) + self.assertIsInstance(result, str) + self.assertTrue(result) + + @skipUnlessDBFeature('validates_explain_options') + def test_unknown_options(self): + with self.assertRaisesMessage(ValueError, 'Unknown options: test, test2'): + Tag.objects.all().explain(test=1, test2=1) + + def test_unknown_format(self): + msg = 'DOES NOT EXIST is not a recognized format.' + if connection.features.supported_explain_formats: + msg += ' Allowed formats: %s' % ', '.join(sorted(connection.features.supported_explain_formats)) + with self.assertRaisesMessage(ValueError, msg): + Tag.objects.all().explain(format='does not exist') + + @unittest.skipUnless(connection.vendor == 'postgresql', 'PostgreSQL specific') + def test_postgres_options(self): + qs = Tag.objects.filter(name='test') + test_options = [ + {'COSTS': False, 'BUFFERS': True, 'ANALYZE': True}, + {'costs': False, 'buffers': True, 'analyze': True}, + {'verbose': True, 'timing': True, 'analyze': True}, + {'verbose': False, 'timing': False, 'analyze': True}, + ] + if connection.pg_version >= 100000: + test_options.append({'summary': True}) + for options in test_options: + with self.subTest(**options), transaction.atomic(): + with CaptureQueriesContext(connection) as captured_queries: + qs.explain(format='text', **options) + self.assertEqual(len(captured_queries), 1) + for name, value in options.items(): + option = '{} {}'.format(name.upper(), 'true' if value else 'false') + self.assertIn(option, captured_queries[0]['sql']) + + @unittest.skipUnless(connection.vendor == 'mysql', 'MySQL specific') + def 
test_mysql_text_to_traditional(self): + # Initialize the cached property, if needed, to prevent a query for + # the MySQL version during the QuerySet evaluation. + connection.features.needs_explain_extended + with CaptureQueriesContext(connection) as captured_queries: + Tag.objects.filter(name='test').explain(format='text') + self.assertEqual(len(captured_queries), 1) + self.assertIn('FORMAT=TRADITIONAL', captured_queries[0]['sql']) + + @unittest.skipUnless(connection.vendor == 'mysql', 'MySQL < 5.7 specific') + def test_mysql_extended(self): + # Inner skip to avoid module level query for MySQL version. + if not connection.features.needs_explain_extended: + raise unittest.SkipTest('MySQL < 5.7 specific') + qs = Tag.objects.filter(name='test') + with CaptureQueriesContext(connection) as captured_queries: + qs.explain(format='json') + self.assertEqual(len(captured_queries), 1) + self.assertNotIn('EXTENDED', captured_queries[0]['sql']) + with CaptureQueriesContext(connection) as captured_queries: + qs.explain(format='text') + self.assertEqual(len(captured_queries), 1) + self.assertNotIn('EXTENDED', captured_queries[0]['sql']) + + +@skipIfDBFeature('supports_explaining_query_execution') +class ExplainUnsupportedTests(TestCase): + + def test_message(self): + msg = 'This backend does not support explaining query execution.' 
+ with self.assertRaisesMessage(NotSupportedError, msg): + Tag.objects.filter(name='test').explain() diff --git a/tests/queries/test_iterator.py b/tests/queries/test_iterator.py new file mode 100644 index 00000000..56f42c21 --- /dev/null +++ b/tests/queries/test_iterator.py @@ -0,0 +1,39 @@ +import datetime +from unittest import mock + +from django.db.models.sql.compiler import cursor_iter +from django.test import TestCase + +from .models import Article + + +class QuerySetIteratorTests(TestCase): + itersize_index_in_mock_args = 3 + + @classmethod + def setUpTestData(cls): + Article.objects.create(name='Article 1', created=datetime.datetime.now()) + Article.objects.create(name='Article 2', created=datetime.datetime.now()) + + def test_iterator_invalid_chunk_size(self): + for size in (0, -1): + with self.subTest(size=size): + with self.assertRaisesMessage(ValueError, 'Chunk size must be strictly positive.'): + Article.objects.iterator(chunk_size=size) + + def test_default_iterator_chunk_size(self): + qs = Article.objects.iterator() + with mock.patch('django.db.models.sql.compiler.cursor_iter', side_effect=cursor_iter) as cursor_iter_mock: + next(qs) + self.assertEqual(cursor_iter_mock.call_count, 1) + mock_args, _mock_kwargs = cursor_iter_mock.call_args + self.assertEqual(mock_args[self.itersize_index_in_mock_args], 2000) + + def test_iterator_chunk_size(self): + batch_size = 3 + qs = Article.objects.iterator(chunk_size=batch_size) + with mock.patch('django.db.models.sql.compiler.cursor_iter', side_effect=cursor_iter) as cursor_iter_mock: + next(qs) + self.assertEqual(cursor_iter_mock.call_count, 1) + mock_args, _mock_kwargs = cursor_iter_mock.call_args + self.assertEqual(mock_args[self.itersize_index_in_mock_args], batch_size) diff --git a/tests/queries/test_q.py b/tests/queries/test_q.py new file mode 100644 index 00000000..9adff07e --- /dev/null +++ b/tests/queries/test_q.py @@ -0,0 +1,105 @@ +from django.db.models import F, Q +from django.test import 
SimpleTestCase + + +class QTests(SimpleTestCase): + def test_combine_and_empty(self): + q = Q(x=1) + self.assertEqual(q & Q(), q) + self.assertEqual(Q() & q, q) + + def test_combine_and_both_empty(self): + self.assertEqual(Q() & Q(), Q()) + + def test_combine_or_empty(self): + q = Q(x=1) + self.assertEqual(q | Q(), q) + self.assertEqual(Q() | q, q) + + def test_combine_or_both_empty(self): + self.assertEqual(Q() | Q(), Q()) + + def test_combine_not_q_object(self): + obj = object() + q = Q(x=1) + with self.assertRaisesMessage(TypeError, str(obj)): + q | obj + with self.assertRaisesMessage(TypeError, str(obj)): + q & obj + + def test_deconstruct(self): + q = Q(price__gt=F('discounted_price')) + path, args, kwargs = q.deconstruct() + self.assertEqual(path, 'django.db.models.Q') + self.assertEqual(args, ()) + self.assertEqual(kwargs, {'price__gt': F('discounted_price')}) + + def test_deconstruct_negated(self): + q = ~Q(price__gt=F('discounted_price')) + path, args, kwargs = q.deconstruct() + self.assertEqual(args, ()) + self.assertEqual(kwargs, { + 'price__gt': F('discounted_price'), + '_negated': True, + }) + + def test_deconstruct_or(self): + q1 = Q(price__gt=F('discounted_price')) + q2 = Q(price=F('discounted_price')) + q = q1 | q2 + path, args, kwargs = q.deconstruct() + self.assertEqual(args, ( + ('price__gt', F('discounted_price')), + ('price', F('discounted_price')), + )) + self.assertEqual(kwargs, {'_connector': 'OR'}) + + def test_deconstruct_and(self): + q1 = Q(price__gt=F('discounted_price')) + q2 = Q(price=F('discounted_price')) + q = q1 & q2 + path, args, kwargs = q.deconstruct() + self.assertEqual(args, ( + ('price__gt', F('discounted_price')), + ('price', F('discounted_price')), + )) + self.assertEqual(kwargs, {}) + + def test_deconstruct_multiple_kwargs(self): + q = Q(price__gt=F('discounted_price'), price=F('discounted_price')) + path, args, kwargs = q.deconstruct() + self.assertEqual(args, ( + ('price', F('discounted_price')), + ('price__gt', 
F('discounted_price')), + )) + self.assertEqual(kwargs, {}) + + def test_deconstruct_nested(self): + q = Q(Q(price__gt=F('discounted_price'))) + path, args, kwargs = q.deconstruct() + self.assertEqual(args, (Q(price__gt=F('discounted_price')),)) + self.assertEqual(kwargs, {}) + + def test_reconstruct(self): + q = Q(price__gt=F('discounted_price')) + path, args, kwargs = q.deconstruct() + self.assertEqual(Q(*args, **kwargs), q) + + def test_reconstruct_negated(self): + q = ~Q(price__gt=F('discounted_price')) + path, args, kwargs = q.deconstruct() + self.assertEqual(Q(*args, **kwargs), q) + + def test_reconstruct_or(self): + q1 = Q(price__gt=F('discounted_price')) + q2 = Q(price=F('discounted_price')) + q = q1 | q2 + path, args, kwargs = q.deconstruct() + self.assertEqual(Q(*args, **kwargs), q) + + def test_reconstruct_and(self): + q1 = Q(price__gt=F('discounted_price')) + q2 = Q(price=F('discounted_price')) + q = q1 & q2 + path, args, kwargs = q.deconstruct() + self.assertEqual(Q(*args, **kwargs), q) diff --git a/tests/queries/test_qs_combinators.py b/tests/queries/test_qs_combinators.py new file mode 100644 index 00000000..b3abfc53 --- /dev/null +++ b/tests/queries/test_qs_combinators.py @@ -0,0 +1,209 @@ +from django.db.models import Exists, F, IntegerField, OuterRef, Value +from django.db.utils import DatabaseError, NotSupportedError +from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature + +from .models import Number, ReservedName + + +@skipUnlessDBFeature('supports_select_union') +class QuerySetSetOperationTests(TestCase): + @classmethod + def setUpTestData(cls): + Number.objects.bulk_create(Number(num=i) for i in range(10)) + + def number_transform(self, value): + return value.num + + def assertNumbersEqual(self, queryset, expected_numbers, ordered=True): + self.assertQuerysetEqual(queryset, expected_numbers, self.number_transform, ordered) + + def test_simple_union(self): + qs1 = Number.objects.filter(num__lte=1) + qs2 = 
Number.objects.filter(num__gte=8) + qs3 = Number.objects.filter(num=5) + self.assertNumbersEqual(qs1.union(qs2, qs3), [0, 1, 5, 8, 9], ordered=False) + + @skipUnlessDBFeature('supports_select_intersection') + def test_simple_intersection(self): + qs1 = Number.objects.filter(num__lte=5) + qs2 = Number.objects.filter(num__gte=5) + qs3 = Number.objects.filter(num__gte=4, num__lte=6) + self.assertNumbersEqual(qs1.intersection(qs2, qs3), [5], ordered=False) + + @skipUnlessDBFeature('supports_select_intersection') + def test_intersection_with_values(self): + ReservedName.objects.create(name='a', order=2) + qs1 = ReservedName.objects.all() + reserved_name = qs1.intersection(qs1).values('name', 'order', 'id').get() + self.assertEqual(reserved_name['name'], 'a') + self.assertEqual(reserved_name['order'], 2) + reserved_name = qs1.intersection(qs1).values_list('name', 'order', 'id').get() + self.assertEqual(reserved_name[:2], ('a', 2)) + + @skipUnlessDBFeature('supports_select_difference') + def test_simple_difference(self): + qs1 = Number.objects.filter(num__lte=5) + qs2 = Number.objects.filter(num__lte=4) + self.assertNumbersEqual(qs1.difference(qs2), [5], ordered=False) + + def test_union_distinct(self): + qs1 = Number.objects.all() + qs2 = Number.objects.all() + self.assertEqual(len(list(qs1.union(qs2, all=True))), 20) + self.assertEqual(len(list(qs1.union(qs2))), 10) + + @skipUnlessDBFeature('supports_select_intersection') + def test_intersection_with_empty_qs(self): + qs1 = Number.objects.all() + qs2 = Number.objects.none() + qs3 = Number.objects.filter(pk__in=[]) + self.assertEqual(len(qs1.intersection(qs2)), 0) + self.assertEqual(len(qs1.intersection(qs3)), 0) + self.assertEqual(len(qs2.intersection(qs1)), 0) + self.assertEqual(len(qs3.intersection(qs1)), 0) + self.assertEqual(len(qs2.intersection(qs2)), 0) + self.assertEqual(len(qs3.intersection(qs3)), 0) + + @skipUnlessDBFeature('supports_select_difference') + def test_difference_with_empty_qs(self): + qs1 = 
Number.objects.all() + qs2 = Number.objects.none() + qs3 = Number.objects.filter(pk__in=[]) + self.assertEqual(len(qs1.difference(qs2)), 10) + self.assertEqual(len(qs1.difference(qs3)), 10) + self.assertEqual(len(qs2.difference(qs1)), 0) + self.assertEqual(len(qs3.difference(qs1)), 0) + self.assertEqual(len(qs2.difference(qs2)), 0) + self.assertEqual(len(qs3.difference(qs3)), 0) + + @skipUnlessDBFeature('supports_select_difference') + def test_difference_with_values(self): + ReservedName.objects.create(name='a', order=2) + qs1 = ReservedName.objects.all() + qs2 = ReservedName.objects.none() + reserved_name = qs1.difference(qs2).values('name', 'order', 'id').get() + self.assertEqual(reserved_name['name'], 'a') + self.assertEqual(reserved_name['order'], 2) + reserved_name = qs1.difference(qs2).values_list('name', 'order', 'id').get() + self.assertEqual(reserved_name[:2], ('a', 2)) + + def test_union_with_empty_qs(self): + qs1 = Number.objects.all() + qs2 = Number.objects.none() + qs3 = Number.objects.filter(pk__in=[]) + self.assertEqual(len(qs1.union(qs2)), 10) + self.assertEqual(len(qs2.union(qs1)), 10) + self.assertEqual(len(qs1.union(qs3)), 10) + self.assertEqual(len(qs3.union(qs1)), 10) + self.assertEqual(len(qs2.union(qs1, qs1, qs1)), 10) + self.assertEqual(len(qs2.union(qs1, qs1, all=True)), 20) + self.assertEqual(len(qs2.union(qs2)), 0) + self.assertEqual(len(qs3.union(qs3)), 0) + + #def test_limits(self): + # qs1 = Number.objects.all() + # qs2 = Number.objects.all() + # self.assertEqual(len(list(qs1.union(qs2)[:2])), 2) + + def test_ordering(self): + qs1 = Number.objects.filter(num__lte=1) + qs2 = Number.objects.filter(num__gte=2, num__lte=3) + self.assertNumbersEqual(qs1.union(qs2).order_by('-num'), [3, 2, 1, 0]) + + def test_union_with_values(self): + ReservedName.objects.create(name='a', order=2) + qs1 = ReservedName.objects.all() + reserved_name = qs1.union(qs1).values('name', 'order', 'id').get() + self.assertEqual(reserved_name['name'], 'a') + 
self.assertEqual(reserved_name['order'], 2) + reserved_name = qs1.union(qs1).values_list('name', 'order', 'id').get() + self.assertEqual(reserved_name[:2], ('a', 2)) + + def test_union_with_two_annotated_values_list(self): + qs1 = Number.objects.filter(num=1).annotate( + count=Value(0, IntegerField()), + ).values_list('num', 'count') + qs2 = Number.objects.filter(num=2).values('pk').annotate( + count=F('num'), + ).annotate( + num=Value(1, IntegerField()), + ).values_list('num', 'count') + self.assertCountEqual(qs1.union(qs2), [(1, 0), (2, 1)]) + + def test_union_with_values_list_on_annotated_and_unannotated(self): + ReservedName.objects.create(name='rn1', order=1) + qs1 = Number.objects.annotate( + has_reserved_name=Exists(ReservedName.objects.filter(order=OuterRef('num'))) + ).filter(has_reserved_name=True) + qs2 = Number.objects.filter(num=9) + self.assertCountEqual(qs1.union(qs2).values_list('num', flat=True), [1, 9]) + + def test_count_union(self): + qs1 = Number.objects.filter(num__lte=1).values('num') + qs2 = Number.objects.filter(num__gte=2, num__lte=3).values('num') + self.assertEqual(qs1.union(qs2).count(), 4) + + def test_count_union_empty_result(self): + qs = Number.objects.filter(pk__in=[]) + self.assertEqual(qs.union(qs).count(), 0) + + @skipUnlessDBFeature('supports_select_difference') + def test_count_difference(self): + qs1 = Number.objects.filter(num__lt=10) + qs2 = Number.objects.filter(num__lt=9) + self.assertEqual(qs1.difference(qs2).count(), 1) + + @skipUnlessDBFeature('supports_select_intersection') + def test_count_intersection(self): + qs1 = Number.objects.filter(num__gte=5) + qs2 = Number.objects.filter(num__lte=5) + self.assertEqual(qs1.intersection(qs2).count(), 1) + + @skipUnlessDBFeature('supports_slicing_ordering_in_compound') + def test_ordering_subqueries(self): + qs1 = Number.objects.order_by('num')[:2] + qs2 = Number.objects.order_by('-num')[:2] + self.assertNumbersEqual(qs1.union(qs2).order_by('-num')[:4], [9, 8, 1, 0]) + + 
@skipIfDBFeature('supports_slicing_ordering_in_compound') + def test_unsupported_ordering_slicing_raises_db_error(self): + qs1 = Number.objects.all() + qs2 = Number.objects.all() + msg = 'LIMIT/OFFSET not allowed in subqueries of compound statements' + with self.assertRaisesMessage(DatabaseError, msg): + list(qs1.union(qs2[:10])) + msg = 'ORDER BY not allowed in subqueries of compound statements' + with self.assertRaisesMessage(DatabaseError, msg): + list(qs1.order_by('id').union(qs2)) + + @skipIfDBFeature('supports_select_intersection') + def test_unsupported_intersection_raises_db_error(self): + qs1 = Number.objects.all() + qs2 = Number.objects.all() + msg = 'intersection is not supported on this database backend' + with self.assertRaisesMessage(NotSupportedError, msg): + list(qs1.intersection(qs2)) + + def test_combining_multiple_models(self): + ReservedName.objects.create(name='99 little bugs', order=99) + qs1 = Number.objects.filter(num=1).values_list('num', flat=True) + qs2 = ReservedName.objects.values_list('order') + self.assertEqual(list(qs1.union(qs2).order_by('num')), [1, 99]) + + def test_order_raises_on_non_selected_column(self): + qs1 = Number.objects.filter().annotate( + annotation=Value(1, IntegerField()), + ).values('annotation', num2=F('num')) + qs2 = Number.objects.filter().values('id', 'num') + # Should not raise + list(qs1.union(qs2).order_by('annotation')) + list(qs1.union(qs2).order_by('num2')) + msg = 'ORDER BY term does not match any column in the result set' + # 'id' is not part of the select + with self.assertRaisesMessage(DatabaseError, msg): + list(qs1.union(qs2).order_by('id')) + # 'num' got realiased to num2 + with self.assertRaisesMessage(DatabaseError, msg): + list(qs1.union(qs2).order_by('num')) + # switched order, now 'exists' again: + list(qs2.union(qs1).order_by('num')) diff --git a/tests/queries/tests.py b/tests/queries/tests.py new file mode 100644 index 00000000..c592086f --- /dev/null +++ b/tests/queries/tests.py @@ -0,0 
+1,3899 @@ +import datetime +import pickle +import unittest +from collections import OrderedDict +from operator import attrgetter + +from django.core.exceptions import EmptyResultSet, FieldError +from django.db import DEFAULT_DB_ALIAS, connection +from django.db.models import Count, F, Q +from django.db.models.sql.constants import LOUTER +from django.db.models.sql.where import NothingNode, WhereNode +from django.test import TestCase, skipUnlessDBFeature +from django.test.utils import CaptureQueriesContext + +from .models import ( + FK1, Annotation, Article, Author, BaseA, Book, CategoryItem, + CategoryRelationship, Celebrity, Channel, Chapter, Child, ChildObjectA, + Classroom, CommonMixedCaseForeignKeys, Company, Cover, CustomPk, + CustomPkTag, Detail, DumbCategory, Eaten, Employment, ExtraInfo, Fan, Food, + Identifier, Individual, Item, Job, JobResponsibilities, Join, LeafA, LeafB, + LoopX, LoopZ, ManagedModel, Member, MixedCaseDbColumnCategoryItem, + MixedCaseFieldCategoryItem, ModelA, ModelB, ModelC, ModelD, MyObject, + NamedCategory, Node, Note, NullableName, Number, ObjectA, ObjectB, ObjectC, + OneToOneCategory, Order, OrderItem, Page, Paragraph, Person, Plaything, + PointerA, Program, ProxyCategory, ProxyObjectA, ProxyObjectB, Ranking, + Related, RelatedIndividual, RelatedObject, Report, ReportComment, + ReservedName, Responsibility, School, SharedConnection, SimpleCategory, + SingleObject, SpecialCategory, Staff, StaffUser, Student, Tag, Task, + Teacher, Ticket21203Child, Ticket21203Parent, Ticket23605A, Ticket23605B, + Ticket23605C, TvChef, Valid, X, +) + + +class Queries1Tests(TestCase): + @classmethod + def setUpTestData(cls): + generic = NamedCategory.objects.create(name="Generic") + cls.t1 = Tag.objects.create(name='t1', category=generic) + cls.t2 = Tag.objects.create(name='t2', parent=cls.t1, category=generic) + cls.t3 = Tag.objects.create(name='t3', parent=cls.t1) + t4 = Tag.objects.create(name='t4', parent=cls.t3) + cls.t5 = 
Tag.objects.create(name='t5', parent=cls.t3) + + cls.n1 = Note.objects.create(note='n1', misc='foo', id=1) + n2 = Note.objects.create(note='n2', misc='bar', id=2) + cls.n3 = Note.objects.create(note='n3', misc='foo', id=3) + + ann1 = Annotation.objects.create(name='a1', tag=cls.t1) + ann1.notes.add(cls.n1) + ann2 = Annotation.objects.create(name='a2', tag=t4) + ann2.notes.add(n2, cls.n3) + + # Create these out of order so that sorting by 'id' will be different to sorting + # by 'info'. Helps detect some problems later. + cls.e2 = ExtraInfo.objects.create(info='e2', note=n2, value=41) + e1 = ExtraInfo.objects.create(info='e1', note=cls.n1, value=42) + + cls.a1 = Author.objects.create(name='a1', num=1001, extra=e1) + cls.a2 = Author.objects.create(name='a2', num=2002, extra=e1) + a3 = Author.objects.create(name='a3', num=3003, extra=cls.e2) + cls.a4 = Author.objects.create(name='a4', num=4004, extra=cls.e2) + + cls.time1 = datetime.datetime(2007, 12, 19, 22, 25, 0) + cls.time2 = datetime.datetime(2007, 12, 19, 21, 0, 0) + time3 = datetime.datetime(2007, 12, 20, 22, 25, 0) + time4 = datetime.datetime(2007, 12, 20, 21, 0, 0) + cls.i1 = Item.objects.create(name='one', created=cls.time1, modified=cls.time1, creator=cls.a1, note=cls.n3) + cls.i1.tags.set([cls.t1, cls.t2]) + cls.i2 = Item.objects.create(name='two', created=cls.time2, creator=cls.a2, note=n2) + cls.i2.tags.set([cls.t1, cls.t3]) + cls.i3 = Item.objects.create(name='three', created=time3, creator=cls.a2, note=cls.n3) + i4 = Item.objects.create(name='four', created=time4, creator=cls.a4, note=cls.n3) + i4.tags.set([t4]) + + cls.r1 = Report.objects.create(name='r1', creator=cls.a1) + Report.objects.create(name='r2', creator=a3) + Report.objects.create(name='r3') + + # Ordering by 'rank' gives us rank2, rank1, rank3. Ordering by the Meta.ordering + # will be rank3, rank2, rank1. 
+ cls.rank1 = Ranking.objects.create(rank=2, author=cls.a2) + + Cover.objects.create(title="first", item=i4) + Cover.objects.create(title="second", item=cls.i2) + + def test_subquery_condition(self): + qs1 = Tag.objects.filter(pk__lte=0) + qs2 = Tag.objects.filter(parent__in=qs1) + qs3 = Tag.objects.filter(parent__in=qs2) + self.assertEqual(qs3.query.subq_aliases, {'T', 'U', 'V'}) + self.assertIn('v0', str(qs3.query).lower()) + qs4 = qs3.filter(parent__in=qs1) + self.assertEqual(qs4.query.subq_aliases, {'T', 'U', 'V'}) + # It is possible to reuse U for the second subquery, no need to use W. + self.assertNotIn('w0', str(qs4.query).lower()) + # So, 'U0."id"' is referenced twice. + self.assertTrue(str(qs4.query).lower().count('u0'), 2) + + def test_ticket1050(self): + self.assertQuerysetEqual( + Item.objects.filter(tags__isnull=True), + [''] + ) + self.assertQuerysetEqual( + Item.objects.filter(tags__id__isnull=True), + [''] + ) + + def test_ticket1801(self): + self.assertQuerysetEqual( + Author.objects.filter(item=self.i2), + [''] + ) + self.assertQuerysetEqual( + Author.objects.filter(item=self.i3), + [''] + ) + self.assertQuerysetEqual( + Author.objects.filter(item=self.i2) & Author.objects.filter(item=self.i3), + [''] + ) + + def test_ticket2306(self): + # Checking that no join types are "left outer" joins. + query = Item.objects.filter(tags=self.t2).query + self.assertNotIn(LOUTER, [x.join_type for x in query.alias_map.values()]) + + self.assertQuerysetEqual( + Item.objects.filter(Q(tags=self.t1)).order_by('name'), + ['', ''] + ) + self.assertQuerysetEqual( + Item.objects.filter(Q(tags=self.t1)).filter(Q(tags=self.t2)), + [''] + ) + self.assertQuerysetEqual( + Item.objects.filter(Q(tags=self.t1)).filter(Q(creator__name='fred') | Q(tags=self.t2)), + [''] + ) + + # Each filter call is processed "at once" against a single table, so this is + # different from the previous example as it tries to find tags that are two + # things at once (rather than two tags). 
+ self.assertQuerysetEqual( + Item.objects.filter(Q(tags=self.t1) & Q(tags=self.t2)), + [] + ) + self.assertQuerysetEqual( + Item.objects.filter(Q(tags=self.t1), Q(creator__name='fred') | Q(tags=self.t2)), + [] + ) + + qs = Author.objects.filter(ranking__rank=2, ranking__id=self.rank1.id) + self.assertQuerysetEqual(list(qs), ['']) + self.assertEqual(2, qs.query.count_active_tables(), 2) + qs = Author.objects.filter(ranking__rank=2).filter(ranking__id=self.rank1.id) + self.assertEqual(qs.query.count_active_tables(), 3) + + def test_ticket4464(self): + self.assertQuerysetEqual( + Item.objects.filter(tags=self.t1).filter(tags=self.t2), + [''] + ) + self.assertQuerysetEqual( + Item.objects.filter(tags__in=[self.t1, self.t2]).distinct().order_by('name'), + ['', ''] + ) + self.assertQuerysetEqual( + Item.objects.filter(tags__in=[self.t1, self.t2]).filter(tags=self.t3), + [''] + ) + + # Make sure .distinct() works with slicing (this was broken in Oracle). + self.assertQuerysetEqual( + Item.objects.filter(tags__in=[self.t1, self.t2]).order_by('name')[:3], + ['', '', ''] + ) + self.assertQuerysetEqual( + Item.objects.filter(tags__in=[self.t1, self.t2]).distinct().order_by('name')[:3], + ['', ''] + ) + + def test_tickets_2080_3592(self): + self.assertQuerysetEqual( + Author.objects.filter(item__name='one') | Author.objects.filter(name='a3'), + ['', ''] + ) + self.assertQuerysetEqual( + Author.objects.filter(Q(item__name='one') | Q(name='a3')), + ['', ''] + ) + self.assertQuerysetEqual( + Author.objects.filter(Q(name='a3') | Q(item__name='one')), + ['', ''] + ) + self.assertQuerysetEqual( + Author.objects.filter(Q(item__name='three') | Q(report__name='r3')), + [''] + ) + + def test_ticket6074(self): + # Merging two empty result sets shouldn't leave a queryset with no constraints + # (which would match everything). 
+ self.assertQuerysetEqual(Author.objects.filter(Q(id__in=[])), []) + self.assertQuerysetEqual( + Author.objects.filter(Q(id__in=[]) | Q(id__in=[])), + [] + ) + + def test_tickets_1878_2939(self): + self.assertEqual(Item.objects.values('creator').distinct().count(), 3) + + # Create something with a duplicate 'name' so that we can test multi-column + # cases (which require some tricky SQL transformations under the covers). + xx = Item(name='four', created=self.time1, creator=self.a2, note=self.n1) + xx.save() + self.assertEqual( + Item.objects.exclude(name='two').values('creator', 'name').distinct().count(), + 4 + ) + self.assertEqual( + ( + Item.objects + .exclude(name='two') + .extra(select={'foo': '%s'}, select_params=(1,)) + .values('creator', 'name', 'foo') + .distinct() + .count() + ), + 4 + ) + self.assertEqual( + ( + Item.objects + .exclude(name='two') + .extra(select={'foo': '%s'}, select_params=(1,)) + .values('creator', 'name') + .distinct() + .count() + ), + 4 + ) + xx.delete() + + def test_ticket7323(self): + self.assertEqual(Item.objects.values('creator', 'name').count(), 4) + + def test_ticket2253(self): + q1 = Item.objects.order_by('name') + q2 = Item.objects.filter(id=self.i1.id) + self.assertQuerysetEqual( + q1, + ['', '', '', ''] + ) + self.assertQuerysetEqual(q2, ['']) + self.assertQuerysetEqual( + (q1 | q2).order_by('name'), + ['', '', '', ''] + ) + self.assertQuerysetEqual((q1 & q2).order_by('name'), ['']) + + q1 = Item.objects.filter(tags=self.t1) + q2 = Item.objects.filter(note=self.n3, tags=self.t2) + q3 = Item.objects.filter(creator=self.a4) + self.assertQuerysetEqual( + ((q1 & q2) | q3).order_by('name'), + ['', ''] + ) + + def test_order_by_tables(self): + q1 = Item.objects.order_by('name') + q2 = Item.objects.filter(id=self.i1.id) + list(q2) + combined_query = (q1 & q2).order_by('name').query + self.assertEqual(len([ + t for t in combined_query.alias_map if combined_query.alias_refcount[t] + ]), 1) + + def test_order_by_join_unref(self): 
+ """ + This test is related to the above one, testing that there aren't + old JOINs in the query. + """ + qs = Celebrity.objects.order_by('greatest_fan__fan_of') + self.assertIn('OUTER JOIN', str(qs.query)) + qs = qs.order_by('id') + self.assertNotIn('OUTER JOIN', str(qs.query)) + + def test_get_clears_ordering(self): + """ + get() should clear ordering for optimization purposes. + """ + with CaptureQueriesContext(connection) as captured_queries: + Author.objects.order_by('name').get(pk=self.a1.pk) + self.assertNotIn('order by', captured_queries[0]['sql'].lower()) + + def test_tickets_4088_4306(self): + self.assertQuerysetEqual( + Report.objects.filter(creator=1001), + [''] + ) + self.assertQuerysetEqual( + Report.objects.filter(creator__num=1001), + [''] + ) + self.assertQuerysetEqual(Report.objects.filter(creator__id=1001), []) + self.assertQuerysetEqual( + Report.objects.filter(creator__id=self.a1.id), + [''] + ) + self.assertQuerysetEqual( + Report.objects.filter(creator__name='a1'), + [''] + ) + + def test_ticket4510(self): + self.assertQuerysetEqual( + Author.objects.filter(report__name='r1'), + [''] + ) + + def test_ticket7378(self): + self.assertQuerysetEqual(self.a1.report_set.all(), ['']) + + def test_tickets_5324_6704(self): + self.assertQuerysetEqual( + Item.objects.filter(tags__name='t4'), + [''] + ) + self.assertQuerysetEqual( + Item.objects.exclude(tags__name='t4').order_by('name').distinct(), + ['', '', ''] + ) + self.assertQuerysetEqual( + Item.objects.exclude(tags__name='t4').order_by('name').distinct().reverse(), + ['', '', ''] + ) + self.assertQuerysetEqual( + Author.objects.exclude(item__name='one').distinct().order_by('name'), + ['', '', ''] + ) + + # Excluding across a m2m relation when there is more than one related + # object associated was problematic. 
+ self.assertQuerysetEqual( + Item.objects.exclude(tags__name='t1').order_by('name'), + ['', ''] + ) + self.assertQuerysetEqual( + Item.objects.exclude(tags__name='t1').exclude(tags__name='t4'), + [''] + ) + + # Excluding from a relation that cannot be NULL should not use outer joins. + query = Item.objects.exclude(creator__in=[self.a1, self.a2]).query + self.assertNotIn(LOUTER, [x.join_type for x in query.alias_map.values()]) + + # Similarly, when one of the joins cannot possibly, ever, involve NULL + # values (Author -> ExtraInfo, in the following), it should never be + # promoted to a left outer join. So the following query should only + # involve one "left outer" join (Author -> Item is 0-to-many). + qs = Author.objects.filter(id=self.a1.id).filter(Q(extra__note=self.n1) | Q(item__note=self.n3)) + self.assertEqual( + len([ + x for x in qs.query.alias_map.values() + if x.join_type == LOUTER and qs.query.alias_refcount[x.table_alias] + ]), + 1 + ) + + # The previous changes shouldn't affect nullable foreign key joins. 
+ self.assertQuerysetEqual( + Tag.objects.filter(parent__isnull=True).order_by('name'), + [''] + ) + self.assertQuerysetEqual( + Tag.objects.exclude(parent__isnull=True).order_by('name'), + ['', '', '', ''] + ) + self.assertQuerysetEqual( + Tag.objects.exclude(Q(parent__name='t1') | Q(parent__isnull=True)).order_by('name'), + ['', ''] + ) + self.assertQuerysetEqual( + Tag.objects.exclude(Q(parent__isnull=True) | Q(parent__name='t1')).order_by('name'), + ['', ''] + ) + self.assertQuerysetEqual( + Tag.objects.exclude(Q(parent__parent__isnull=True)).order_by('name'), + ['', ''] + ) + self.assertQuerysetEqual( + Tag.objects.filter(~Q(parent__parent__isnull=True)).order_by('name'), + ['', ''] + ) + + def test_ticket2091(self): + t = Tag.objects.get(name='t4') + self.assertQuerysetEqual( + Item.objects.filter(tags__in=[t]), + [''] + ) + + def test_avoid_infinite_loop_on_too_many_subqueries(self): + x = Tag.objects.filter(pk=1) + local_recursion_limit = 127 + msg = 'Maximum recursion depth exceeded: too many subqueries.' + with self.assertRaisesMessage(RuntimeError, msg): + for i in range(local_recursion_limit * 2): + x = Tag.objects.filter(pk__in=x) + + def test_reasonable_number_of_subq_aliases(self): + x = Tag.objects.filter(pk=1) + for _ in range(20): + x = Tag.objects.filter(pk__in=x) + self.assertEqual( + x.query.subq_aliases, { + 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'AA', 'AB', 'AC', 'AD', + 'AE', 'AF', 'AG', 'AH', 'AI', 'AJ', 'AK', 'AL', 'AM', 'AN', + } + ) + + def test_heterogeneous_qs_combination(self): + # Combining querysets built on different models should behave in a well-defined + # fashion. We raise an error. 
+ with self.assertRaisesMessage(AssertionError, 'Cannot combine queries on two different base models.'): + Author.objects.all() & Tag.objects.all() + with self.assertRaisesMessage(AssertionError, 'Cannot combine queries on two different base models.'): + Author.objects.all() | Tag.objects.all() + + def test_ticket3141(self): + self.assertEqual(Author.objects.extra(select={'foo': '1'}).count(), 4) + self.assertEqual( + Author.objects.extra(select={'foo': '%s'}, select_params=(1,)).count(), + 4 + ) + + def test_ticket2400(self): + self.assertQuerysetEqual( + Author.objects.filter(item__isnull=True), + [''] + ) + self.assertQuerysetEqual( + Tag.objects.filter(item__isnull=True), + [''] + ) + + def test_ticket2496(self): + self.assertQuerysetEqual( + Item.objects.extra(tables=['queries_author']).select_related().order_by('name')[:1], + [''] + ) + + def test_error_raised_on_filter_with_dictionary(self): + with self.assertRaisesMessage(FieldError, 'Cannot parse keyword query as dict'): + Note.objects.filter({'note': 'n1', 'misc': 'foo'}) + + def test_tickets_2076_7256(self): + # Ordering on related tables should be possible, even if the table is + # not otherwise involved. + self.assertQuerysetEqual( + Item.objects.order_by('note__note', 'name'), + ['', '', '', ''] + ) + + # Ordering on a related field should use the remote model's default + # ordering as a final step. + self.assertQuerysetEqual( + Author.objects.order_by('extra', '-name'), + ['', '', '', ''] + ) + + # Using remote model default ordering can span multiple models (in this + # case, Cover is ordered by Item's default, which uses Note's default). + self.assertQuerysetEqual( + Cover.objects.all(), + ['', ''] + ) + + # If the remote model does not have a default ordering, we order by its 'id' + # field. + self.assertQuerysetEqual( + Item.objects.order_by('creator', 'name'), + ['', '', '', ''] + ) + + # Ordering by a many-valued attribute (e.g. 
a many-to-many or reverse + # ForeignKey) is legal, but the results might not make sense. That + # isn't Django's problem. Garbage in, garbage out. + self.assertQuerysetEqual( + Item.objects.filter(tags__isnull=False).order_by('tags', 'id'), + ['', '', '', '', ''] + ) + + # If we replace the default ordering, Django adjusts the required + # tables automatically. Item normally requires a join with Note to do + # the default ordering, but that isn't needed here. + qs = Item.objects.order_by('name') + self.assertQuerysetEqual( + qs, + ['', '', '', ''] + ) + self.assertEqual(len(qs.query.alias_map), 1) + + def test_tickets_2874_3002(self): + qs = Item.objects.select_related().order_by('note__note', 'name') + self.assertQuerysetEqual( + qs, + ['', '', '', ''] + ) + + # This is also a good select_related() test because there are multiple + # Note entries in the SQL. The two Note items should be different. + self.assertTrue(repr(qs[0].note), '') + self.assertEqual(repr(qs[0].creator.extra.note), '') + + def test_ticket3037(self): + self.assertQuerysetEqual( + Item.objects.filter(Q(creator__name='a3', name='two') | Q(creator__name='a4', name='four')), + [''] + ) + + def test_tickets_5321_7070(self): + # Ordering columns must be included in the output columns. Note that + # this means results that might otherwise be distinct are not (if there + # are multiple values in the ordering cols), as in this example. This + # isn't a bug; it's a warning to be careful with the selection of + # ordering columns. + self.assertSequenceEqual( + Note.objects.values('misc').distinct().order_by('note', '-misc'), + [{'misc': 'foo'}, {'misc': 'bar'}, {'misc': 'foo'}] + ) + + def test_ticket4358(self): + # If you don't pass any fields to values(), relation fields are + # returned as "foo_id" keys, not "foo". For consistency, you should be + # able to pass "foo_id" in the fields list and have it work, too. We + # actually allow both "foo" and "foo_id". 
+ # The *_id version is returned by default. + self.assertIn('note_id', ExtraInfo.objects.values()[0]) + # You can also pass it in explicitly. + self.assertSequenceEqual(ExtraInfo.objects.values('note_id'), [{'note_id': 1}, {'note_id': 2}]) + # ...or use the field name. + self.assertSequenceEqual(ExtraInfo.objects.values('note'), [{'note': 1}, {'note': 2}]) + + def test_ticket2902(self): + # Parameters can be given to extra_select, *if* you use an OrderedDict. + + # (First we need to know which order the keys fall in "naturally" on + # your system, so we can put things in the wrong way around from + # normal. A normal dict would thus fail.) + s = [('a', '%s'), ('b', '%s')] + params = ['one', 'two'] + if list({'a': 1, 'b': 2}) == ['a', 'b']: + s.reverse() + params.reverse() + + d = Item.objects.extra(select=OrderedDict(s), select_params=params).values('a', 'b')[0] + self.assertEqual(d, {'a': 'one', 'b': 'two'}) + + # Order by the number of tags attached to an item. + qs = ( + Item.objects + .extra(select={ + 'count': 'select count(*) from queries_item_tags where queries_item_tags.item_id = queries_item.id' + }) + .order_by('-count') + ) + self.assertEqual([o.count for o in qs], [2, 2, 1, 0]) + + def test_ticket6154(self): + # Multiple filter statements are joined using "AND" all the time. 
+ + self.assertQuerysetEqual( + Author.objects.filter(id=self.a1.id).filter(Q(extra__note=self.n1) | Q(item__note=self.n3)), + [''] + ) + self.assertQuerysetEqual( + Author.objects.filter(Q(extra__note=self.n1) | Q(item__note=self.n3)).filter(id=self.a1.id), + [''] + ) + + def test_ticket6981(self): + self.assertQuerysetEqual( + Tag.objects.select_related('parent').order_by('name'), + ['', '', '', '', ''] + ) + + def test_ticket9926(self): + self.assertQuerysetEqual( + Tag.objects.select_related("parent", "category").order_by('name'), + ['', '', '', '', ''] + ) + self.assertQuerysetEqual( + Tag.objects.select_related('parent', "parent__category").order_by('name'), + ['', '', '', '', ''] + ) + + def test_tickets_6180_6203(self): + # Dates with limits and/or counts + self.assertEqual(Item.objects.count(), 4) + self.assertEqual(Item.objects.datetimes('created', 'month').count(), 1) + self.assertEqual(Item.objects.datetimes('created', 'day').count(), 2) + self.assertEqual(len(Item.objects.datetimes('created', 'day')), 2) + self.assertEqual(Item.objects.datetimes('created', 'day')[0], datetime.datetime(2007, 12, 19, 0, 0)) + + def test_tickets_7087_12242(self): + # Dates with extra select columns + self.assertQuerysetEqual( + Item.objects.datetimes('created', 'day').extra(select={'a': 1}), + ['datetime.datetime(2007, 12, 19, 0, 0)', 'datetime.datetime(2007, 12, 20, 0, 0)'] + ) + self.assertQuerysetEqual( + Item.objects.extra(select={'a': 1}).datetimes('created', 'day'), + ['datetime.datetime(2007, 12, 19, 0, 0)', 'datetime.datetime(2007, 12, 20, 0, 0)'] + ) + + name = "one" + self.assertQuerysetEqual( + Item.objects.datetimes('created', 'day').extra(where=['name=%s'], params=[name]), + ['datetime.datetime(2007, 12, 19, 0, 0)'] + ) + + self.assertQuerysetEqual( + Item.objects.extra(where=['name=%s'], params=[name]).datetimes('created', 'day'), + ['datetime.datetime(2007, 12, 19, 0, 0)'] + ) + + def test_ticket7155(self): + # Nullable dates + self.assertQuerysetEqual( + 
Item.objects.datetimes('modified', 'day'), + ['datetime.datetime(2007, 12, 19, 0, 0)'] + ) + + def test_ticket7098(self): + # Make sure semi-deprecated ordering by related models syntax still + # works. + self.assertSequenceEqual( + Item.objects.values('note__note').order_by('queries_note.note', 'id'), + [{'note__note': 'n2'}, {'note__note': 'n3'}, {'note__note': 'n3'}, {'note__note': 'n3'}] + ) + + def test_ticket7096(self): + # Make sure exclude() with multiple conditions continues to work. + self.assertQuerysetEqual( + Tag.objects.filter(parent=self.t1, name='t3').order_by('name'), + [''] + ) + self.assertQuerysetEqual( + Tag.objects.exclude(parent=self.t1, name='t3').order_by('name'), + ['', '', '', ''] + ) + self.assertQuerysetEqual( + Item.objects.exclude(tags__name='t1', name='one').order_by('name').distinct(), + ['', '', ''] + ) + self.assertQuerysetEqual( + Item.objects.filter(name__in=['three', 'four']).exclude(tags__name='t1').order_by('name'), + ['', ''] + ) + + # More twisted cases, involving nested negations. + self.assertQuerysetEqual( + Item.objects.exclude(~Q(tags__name='t1', name='one')), + [''] + ) + self.assertQuerysetEqual( + Item.objects.filter(~Q(tags__name='t1', name='one'), name='two'), + [''] + ) + self.assertQuerysetEqual( + Item.objects.exclude(~Q(tags__name='t1', name='one'), name='two'), + ['', '', ''] + ) + + def test_tickets_7204_7506(self): + # Make sure querysets with related fields can be pickled. If this + # doesn't crash, it's a Good Thing. + pickle.dumps(Item.objects.all()) + + def test_ticket7813(self): + # We should also be able to pickle things that use select_related(). + # The only tricky thing here is to ensure that we do the related + # selections properly after unpickling. 
+ qs = Item.objects.select_related() + query = qs.query.get_compiler(qs.db).as_sql()[0] + query2 = pickle.loads(pickle.dumps(qs.query)) + self.assertEqual( + query2.get_compiler(qs.db).as_sql()[0], + query + ) + + def test_deferred_load_qs_pickling(self): + # Check pickling of deferred-loading querysets + qs = Item.objects.defer('name', 'creator') + q2 = pickle.loads(pickle.dumps(qs)) + self.assertEqual(list(qs), list(q2)) + q3 = pickle.loads(pickle.dumps(qs, pickle.HIGHEST_PROTOCOL)) + self.assertEqual(list(qs), list(q3)) + + def test_ticket7277(self): + self.assertQuerysetEqual( + self.n1.annotation_set.filter( + Q(tag=self.t5) | Q(tag__children=self.t5) | Q(tag__children__children=self.t5) + ), + [''] + ) + + def test_tickets_7448_7707(self): + # Complex objects should be converted to strings before being used in + # lookups. + self.assertQuerysetEqual( + Item.objects.filter(created__in=[self.time1, self.time2]), + ['', ''] + ) + + def test_ticket7235(self): + # An EmptyQuerySet should not raise exceptions if it is filtered. 
+ Eaten.objects.create(meal='m') + q = Eaten.objects.none() + with self.assertNumQueries(0): + self.assertQuerysetEqual(q.all(), []) + self.assertQuerysetEqual(q.filter(meal='m'), []) + self.assertQuerysetEqual(q.exclude(meal='m'), []) + self.assertQuerysetEqual(q.complex_filter({'pk': 1}), []) + self.assertQuerysetEqual(q.select_related('food'), []) + self.assertQuerysetEqual(q.annotate(Count('food')), []) + self.assertQuerysetEqual(q.order_by('meal', 'food'), []) + self.assertQuerysetEqual(q.distinct(), []) + self.assertQuerysetEqual( + q.extra(select={'foo': "1"}), + [] + ) + self.assertQuerysetEqual(q.reverse(), []) + q.query.low_mark = 1 + with self.assertRaisesMessage(AssertionError, 'Cannot change a query once a slice has been taken'): + q.extra(select={'foo': "1"}) + self.assertQuerysetEqual(q.defer('meal'), []) + self.assertQuerysetEqual(q.only('meal'), []) + + def test_ticket7791(self): + # There were "issues" when ordering and distinct-ing on fields related + # via ForeignKeys. + self.assertEqual( + len(Note.objects.order_by('extrainfo__info').distinct()), + 3 + ) + + # Pickling of QuerySets using datetimes() should work. + qs = Item.objects.datetimes('created', 'month') + pickle.loads(pickle.dumps(qs)) + + def test_ticket9997(self): + # If a ValuesList or Values queryset is passed as an inner query, we + # make sure it's only requesting a single value and use that as the + # thing to select. + self.assertQuerysetEqual( + Tag.objects.filter(name__in=Tag.objects.filter(parent=self.t1).values('name')), + ['', ''] + ) + + # Multi-valued values() and values_list() querysets should raise errors. 
+ with self.assertRaisesMessage(TypeError, 'Cannot use multi-field values as a filter value.'): + Tag.objects.filter(name__in=Tag.objects.filter(parent=self.t1).values('name', 'id')) + with self.assertRaisesMessage(TypeError, 'Cannot use multi-field values as a filter value.'): + Tag.objects.filter(name__in=Tag.objects.filter(parent=self.t1).values_list('name', 'id')) + + def test_ticket9985(self): + # qs.values_list(...).values(...) combinations should work. + self.assertSequenceEqual( + Note.objects.values_list("note", flat=True).values("id").order_by("id"), + [{'id': 1}, {'id': 2}, {'id': 3}] + ) + self.assertQuerysetEqual( + Annotation.objects.filter(notes__in=Note.objects.filter(note="n1").values_list('note').values('id')), + [''] + ) + + def test_ticket10205(self): + # When bailing out early because of an empty "__in" filter, we need + # to set things up correctly internally so that subqueries can continue properly. + self.assertEqual(Tag.objects.filter(name__in=()).update(name="foo"), 0) + + def test_ticket10432(self): + # Testing an empty "__in" filter with a generator as the value. 
+ def f(): + return iter([]) + n_obj = Note.objects.all()[0] + + def g(): + yield n_obj.pk + self.assertQuerysetEqual(Note.objects.filter(pk__in=f()), []) + self.assertEqual(list(Note.objects.filter(pk__in=g())), [n_obj]) + + def test_ticket10742(self): + # Queries used in an __in clause don't execute subqueries + + subq = Author.objects.filter(num__lt=3000) + qs = Author.objects.filter(pk__in=subq) + self.assertQuerysetEqual(qs, ['', '']) + + # The subquery result cache should not be populated + self.assertIsNone(subq._result_cache) + + subq = Author.objects.filter(num__lt=3000) + qs = Author.objects.exclude(pk__in=subq) + self.assertQuerysetEqual(qs, ['', '']) + + # The subquery result cache should not be populated + self.assertIsNone(subq._result_cache) + + subq = Author.objects.filter(num__lt=3000) + self.assertQuerysetEqual( + Author.objects.filter(Q(pk__in=subq) & Q(name='a1')), + [''] + ) + + # The subquery result cache should not be populated + self.assertIsNone(subq._result_cache) + + def test_ticket7076(self): + # Excluding shouldn't eliminate NULL entries. + self.assertQuerysetEqual( + Item.objects.exclude(modified=self.time1).order_by('name'), + ['', '', ''] + ) + self.assertQuerysetEqual( + Tag.objects.exclude(parent__name=self.t1.name), + ['', '', ''] + ) + + def test_ticket7181(self): + # Ordering by related tables should accommodate nullable fields (this + # test is a little tricky, since NULL ordering is database dependent. + # Instead, we just count the number of results). + self.assertEqual(len(Tag.objects.order_by('parent__name')), 5) + + # Empty querysets can be merged with others. 
+ self.assertQuerysetEqual( + Note.objects.none() | Note.objects.all(), + ['', '', ''] + ) + self.assertQuerysetEqual( + Note.objects.all() | Note.objects.none(), + ['', '', ''] + ) + self.assertQuerysetEqual(Note.objects.none() & Note.objects.all(), []) + self.assertQuerysetEqual(Note.objects.all() & Note.objects.none(), []) + + def test_ticket9411(self): + # Make sure bump_prefix() (an internal Query method) doesn't (re-)break. It's + # sufficient that this query runs without error. + qs = Tag.objects.values_list('id', flat=True).order_by('id') + qs.query.bump_prefix(qs.query) + first = qs[0] + self.assertEqual(list(qs), list(range(first, first + 5))) + + def test_ticket8439(self): + # Complex combinations of conjunctions, disjunctions and nullable + # relations. + self.assertQuerysetEqual( + Author.objects.filter(Q(item__note__extrainfo=self.e2) | Q(report=self.r1, name='xyz')), + [''] + ) + self.assertQuerysetEqual( + Author.objects.filter(Q(report=self.r1, name='xyz') | Q(item__note__extrainfo=self.e2)), + [''] + ) + self.assertQuerysetEqual( + Annotation.objects.filter(Q(tag__parent=self.t1) | Q(notes__note='n1', name='a1')), + [''] + ) + xx = ExtraInfo.objects.create(info='xx', note=self.n3) + self.assertQuerysetEqual( + Note.objects.filter(Q(extrainfo__author=self.a1) | Q(extrainfo=xx)), + ['', ''] + ) + q = Note.objects.filter(Q(extrainfo__author=self.a1) | Q(extrainfo=xx)).query + self.assertEqual( + len([x for x in q.alias_map.values() if x.join_type == LOUTER and q.alias_refcount[x.table_alias]]), + 1 + ) + + def test_ticket17429(self): + """ + Meta.ordering=None works the same as Meta.ordering=[] + """ + original_ordering = Tag._meta.ordering + Tag._meta.ordering = None + try: + self.assertQuerysetEqual( + Tag.objects.all(), + ['', '', '', '', ''], + ordered=False + ) + finally: + Tag._meta.ordering = original_ordering + + def test_exclude(self): + self.assertQuerysetEqual( + Item.objects.exclude(tags__name='t4'), + [repr(i) for i in 
Item.objects.filter(~Q(tags__name='t4'))]) + self.assertQuerysetEqual( + Item.objects.exclude(Q(tags__name='t4') | Q(tags__name='t3')), + [repr(i) for i in Item.objects.filter(~(Q(tags__name='t4') | Q(tags__name='t3')))]) + self.assertQuerysetEqual( + Item.objects.exclude(Q(tags__name='t4') | ~Q(tags__name='t3')), + [repr(i) for i in Item.objects.filter(~(Q(tags__name='t4') | ~Q(tags__name='t3')))]) + + def test_nested_exclude(self): + self.assertQuerysetEqual( + Item.objects.exclude(~Q(tags__name='t4')), + [repr(i) for i in Item.objects.filter(~~Q(tags__name='t4'))]) + + def test_double_exclude(self): + self.assertQuerysetEqual( + Item.objects.filter(Q(tags__name='t4')), + [repr(i) for i in Item.objects.filter(~~Q(tags__name='t4'))]) + self.assertQuerysetEqual( + Item.objects.filter(Q(tags__name='t4')), + [repr(i) for i in Item.objects.filter(~Q(~Q(tags__name='t4')))]) + + def test_exclude_in(self): + self.assertQuerysetEqual( + Item.objects.exclude(Q(tags__name__in=['t4', 't3'])), + [repr(i) for i in Item.objects.filter(~Q(tags__name__in=['t4', 't3']))]) + self.assertQuerysetEqual( + Item.objects.filter(Q(tags__name__in=['t4', 't3'])), + [repr(i) for i in Item.objects.filter(~~Q(tags__name__in=['t4', 't3']))]) + + def test_ticket_10790_1(self): + # Querying direct fields with isnull should trim the left outer join. + # It also should not create INNER JOIN. 
+ q = Tag.objects.filter(parent__isnull=True) + + self.assertQuerysetEqual(q, ['']) + self.assertNotIn('JOIN', str(q.query)) + + q = Tag.objects.filter(parent__isnull=False) + + self.assertQuerysetEqual( + q, + ['', '', '', ''], + ) + self.assertNotIn('JOIN', str(q.query)) + + q = Tag.objects.exclude(parent__isnull=True) + self.assertQuerysetEqual( + q, + ['', '', '', ''], + ) + self.assertNotIn('JOIN', str(q.query)) + + q = Tag.objects.exclude(parent__isnull=False) + self.assertQuerysetEqual(q, ['']) + self.assertNotIn('JOIN', str(q.query)) + + q = Tag.objects.exclude(parent__parent__isnull=False) + + self.assertQuerysetEqual( + q, + ['', '', ''], + ) + self.assertEqual(str(q.query).count('LEFT OUTER JOIN'), 1) + self.assertNotIn('INNER JOIN', str(q.query)) + + def test_ticket_10790_2(self): + # Querying across several tables should strip only the last outer join, + # while preserving the preceding inner joins. + q = Tag.objects.filter(parent__parent__isnull=False) + + self.assertQuerysetEqual( + q, + ['', ''], + ) + self.assertEqual(str(q.query).count('LEFT OUTER JOIN'), 0) + self.assertEqual(str(q.query).count('INNER JOIN'), 1) + + # Querying without isnull should not convert anything to left outer join. + q = Tag.objects.filter(parent__parent=self.t1) + self.assertQuerysetEqual( + q, + ['', ''], + ) + self.assertEqual(str(q.query).count('LEFT OUTER JOIN'), 0) + self.assertEqual(str(q.query).count('INNER JOIN'), 1) + + def test_ticket_10790_3(self): + # Querying via indirect fields should populate the left outer join + q = NamedCategory.objects.filter(tag__isnull=True) + self.assertEqual(str(q.query).count('LEFT OUTER JOIN'), 1) + # join to dumbcategory ptr_id + self.assertEqual(str(q.query).count('INNER JOIN'), 1) + self.assertQuerysetEqual(q, []) + + # Querying across several tables should strip only the last join, while + # preserving the preceding left outer joins. 
+ q = NamedCategory.objects.filter(tag__parent__isnull=True) + self.assertEqual(str(q.query).count('INNER JOIN'), 1) + self.assertEqual(str(q.query).count('LEFT OUTER JOIN'), 1) + self.assertQuerysetEqual(q, ['']) + + def test_ticket_10790_4(self): + # Querying across m2m field should not strip the m2m table from join. + q = Author.objects.filter(item__tags__isnull=True) + self.assertQuerysetEqual( + q, + ['', ''], + ) + self.assertEqual(str(q.query).count('LEFT OUTER JOIN'), 2) + self.assertNotIn('INNER JOIN', str(q.query)) + + q = Author.objects.filter(item__tags__parent__isnull=True) + self.assertQuerysetEqual( + q, + ['', '', '', ''], + ) + self.assertEqual(str(q.query).count('LEFT OUTER JOIN'), 3) + self.assertNotIn('INNER JOIN', str(q.query)) + + def test_ticket_10790_5(self): + # Querying with isnull=False across m2m field should not create outer joins + q = Author.objects.filter(item__tags__isnull=False) + self.assertQuerysetEqual( + q, + ['', '', '', '', ''] + ) + self.assertEqual(str(q.query).count('LEFT OUTER JOIN'), 0) + self.assertEqual(str(q.query).count('INNER JOIN'), 2) + + q = Author.objects.filter(item__tags__parent__isnull=False) + self.assertQuerysetEqual( + q, + ['', '', ''] + ) + self.assertEqual(str(q.query).count('LEFT OUTER JOIN'), 0) + self.assertEqual(str(q.query).count('INNER JOIN'), 3) + + q = Author.objects.filter(item__tags__parent__parent__isnull=False) + self.assertQuerysetEqual( + q, + [''] + ) + self.assertEqual(str(q.query).count('LEFT OUTER JOIN'), 0) + self.assertEqual(str(q.query).count('INNER JOIN'), 4) + + def test_ticket_10790_6(self): + # Querying with isnull=True across m2m field should not create inner joins + # and strip last outer join + q = Author.objects.filter(item__tags__parent__parent__isnull=True) + self.assertQuerysetEqual( + q, + ['', '', '', '', + '', ''] + ) + self.assertEqual(str(q.query).count('LEFT OUTER JOIN'), 4) + self.assertEqual(str(q.query).count('INNER JOIN'), 0) + + q = 
Author.objects.filter(item__tags__parent__isnull=True) + self.assertQuerysetEqual( + q, + ['', '', '', ''] + ) + self.assertEqual(str(q.query).count('LEFT OUTER JOIN'), 3) + self.assertEqual(str(q.query).count('INNER JOIN'), 0) + + def test_ticket_10790_7(self): + # Reverse querying with isnull should not strip the join + q = Author.objects.filter(item__isnull=True) + self.assertQuerysetEqual( + q, + [''] + ) + self.assertEqual(str(q.query).count('LEFT OUTER JOIN'), 1) + self.assertEqual(str(q.query).count('INNER JOIN'), 0) + + q = Author.objects.filter(item__isnull=False) + self.assertQuerysetEqual( + q, + ['', '', '', ''] + ) + self.assertEqual(str(q.query).count('LEFT OUTER JOIN'), 0) + self.assertEqual(str(q.query).count('INNER JOIN'), 1) + + def test_ticket_10790_8(self): + # Querying with combined q-objects should also strip the left outer join + q = Tag.objects.filter(Q(parent__isnull=True) | Q(parent=self.t1)) + self.assertQuerysetEqual( + q, + ['', '', ''] + ) + self.assertEqual(str(q.query).count('LEFT OUTER JOIN'), 0) + self.assertEqual(str(q.query).count('INNER JOIN'), 0) + + def test_ticket_10790_combine(self): + # Combining queries should not re-populate the left outer join + q1 = Tag.objects.filter(parent__isnull=True) + q2 = Tag.objects.filter(parent__isnull=False) + + q3 = q1 | q2 + self.assertQuerysetEqual( + q3, + ['', '', '', '', ''], + ) + self.assertEqual(str(q3.query).count('LEFT OUTER JOIN'), 0) + self.assertEqual(str(q3.query).count('INNER JOIN'), 0) + + q3 = q1 & q2 + self.assertQuerysetEqual(q3, []) + self.assertEqual(str(q3.query).count('LEFT OUTER JOIN'), 0) + self.assertEqual(str(q3.query).count('INNER JOIN'), 0) + + q2 = Tag.objects.filter(parent=self.t1) + q3 = q1 | q2 + self.assertQuerysetEqual( + q3, + ['', '', ''] + ) + self.assertEqual(str(q3.query).count('LEFT OUTER JOIN'), 0) + self.assertEqual(str(q3.query).count('INNER JOIN'), 0) + + q3 = q2 | q1 + self.assertQuerysetEqual( + q3, + ['', '', ''] + ) + 
self.assertEqual(str(q3.query).count('LEFT OUTER JOIN'), 0) + self.assertEqual(str(q3.query).count('INNER JOIN'), 0) + + q1 = Tag.objects.filter(parent__isnull=True) + q2 = Tag.objects.filter(parent__parent__isnull=True) + + q3 = q1 | q2 + self.assertQuerysetEqual( + q3, + ['', '', ''] + ) + self.assertEqual(str(q3.query).count('LEFT OUTER JOIN'), 1) + self.assertEqual(str(q3.query).count('INNER JOIN'), 0) + + q3 = q2 | q1 + self.assertQuerysetEqual( + q3, + ['', '', ''] + ) + self.assertEqual(str(q3.query).count('LEFT OUTER JOIN'), 1) + self.assertEqual(str(q3.query).count('INNER JOIN'), 0) + + def test_ticket19672(self): + self.assertQuerysetEqual( + Report.objects.filter(Q(creator__isnull=False) & ~Q(creator__extra__value=41)), + [''] + ) + + def test_ticket_20250(self): + # A negated Q along with an annotated queryset failed in Django 1.4 + qs = Author.objects.annotate(Count('item')) + qs = qs.filter(~Q(extra__value=0)) + + self.assertIn('SELECT', str(qs.query)) + self.assertQuerysetEqual( + qs, + ['', '', '', ''] + ) + + def test_lookup_constraint_fielderror(self): + msg = ( + "Cannot resolve keyword 'unknown_field' into field. Choices are: " + "annotation, category, category_id, children, id, item, " + "managedmodel, name, note, parent, parent_id" + ) + with self.assertRaisesMessage(FieldError, msg): + Tag.objects.filter(unknown_field__name='generic') + + def test_common_mixed_case_foreign_keys(self): + """ + Valid query should be generated when fields fetched from joined tables + include FKs whose names only differ by case. 
+ """ + c1 = SimpleCategory.objects.create(name='c1') + c2 = SimpleCategory.objects.create(name='c2') + c3 = SimpleCategory.objects.create(name='c3') + category = CategoryItem.objects.create(category=c1) + mixed_case_field_category = MixedCaseFieldCategoryItem.objects.create(CaTeGoRy=c2) + mixed_case_db_column_category = MixedCaseDbColumnCategoryItem.objects.create(category=c3) + CommonMixedCaseForeignKeys.objects.create( + category=category, + mixed_case_field_category=mixed_case_field_category, + mixed_case_db_column_category=mixed_case_db_column_category, + ) + qs = CommonMixedCaseForeignKeys.objects.values( + 'category', + 'mixed_case_field_category', + 'mixed_case_db_column_category', + 'category__category', + 'mixed_case_field_category__CaTeGoRy', + 'mixed_case_db_column_category__category', + ) + self.assertTrue(qs.first()) + + +class Queries2Tests(TestCase): + @classmethod + def setUpTestData(cls): + Number.objects.create(num=4) + Number.objects.create(num=8) + Number.objects.create(num=12) + + def test_ticket4289(self): + # A slight variation on the restricting the filtering choices by the + # lookup constraints. + self.assertQuerysetEqual(Number.objects.filter(num__lt=4), []) + self.assertQuerysetEqual(Number.objects.filter(num__gt=8, num__lt=12), []) + self.assertQuerysetEqual( + Number.objects.filter(num__gt=8, num__lt=13), + [''] + ) + self.assertQuerysetEqual( + Number.objects.filter(Q(num__lt=4) | Q(num__gt=8, num__lt=12)), + [] + ) + self.assertQuerysetEqual( + Number.objects.filter(Q(num__gt=8, num__lt=12) | Q(num__lt=4)), + [] + ) + self.assertQuerysetEqual( + Number.objects.filter(Q(num__gt=8) & Q(num__lt=12) | Q(num__lt=4)), + [] + ) + self.assertQuerysetEqual( + Number.objects.filter(Q(num__gt=7) & Q(num__lt=12) | Q(num__lt=4)), + [''] + ) + + def test_ticket12239(self): + # Custom lookups are registered to round float values correctly on gte + # and lt IntegerField queries. 
+ self.assertQuerysetEqual( + Number.objects.filter(num__gt=11.9), + [''] + ) + self.assertQuerysetEqual(Number.objects.filter(num__gt=12), []) + self.assertQuerysetEqual(Number.objects.filter(num__gt=12.0), []) + self.assertQuerysetEqual(Number.objects.filter(num__gt=12.1), []) + self.assertQuerysetEqual( + Number.objects.filter(num__lt=12), + ['', ''], + ordered=False + ) + self.assertQuerysetEqual( + Number.objects.filter(num__lt=12.0), + ['', ''], + ordered=False + ) + self.assertQuerysetEqual( + Number.objects.filter(num__lt=12.1), + ['', '', ''], + ordered=False + ) + self.assertQuerysetEqual( + Number.objects.filter(num__gte=11.9), + [''] + ) + self.assertQuerysetEqual( + Number.objects.filter(num__gte=12), + [''] + ) + self.assertQuerysetEqual( + Number.objects.filter(num__gte=12.0), + [''] + ) + self.assertQuerysetEqual(Number.objects.filter(num__gte=12.1), []) + self.assertQuerysetEqual(Number.objects.filter(num__gte=12.9), []) + self.assertQuerysetEqual( + Number.objects.filter(num__lte=11.9), + ['', ''], + ordered=False + ) + self.assertQuerysetEqual( + Number.objects.filter(num__lte=12), + ['', '', ''], + ordered=False + ) + self.assertQuerysetEqual( + Number.objects.filter(num__lte=12.0), + ['', '', ''], + ordered=False + ) + self.assertQuerysetEqual( + Number.objects.filter(num__lte=12.1), + ['', '', ''], + ordered=False + ) + self.assertQuerysetEqual( + Number.objects.filter(num__lte=12.9), + ['', '', ''], + ordered=False + ) + + def test_ticket7759(self): + # Count should work with a partially read result set. + count = Number.objects.count() + qs = Number.objects.all() + + def run(): + for obj in qs: + return qs.count() == count + self.assertTrue(run()) + + +class Queries3Tests(TestCase): + def test_ticket7107(self): + # This shouldn't create an infinite loop. + self.assertQuerysetEqual(Valid.objects.all(), []) + + def test_ticket8683(self): + # An error should be raised when QuerySet.datetimes() is passed the + # wrong type of field. 
+ with self.assertRaisesMessage(AssertionError, "'name' isn't a DateField, TimeField, or DateTimeField."): + Item.objects.datetimes('name', 'month') + + def test_ticket22023(self): + with self.assertRaisesMessage(TypeError, "Cannot call only() after .values() or .values_list()"): + Valid.objects.values().only() + + with self.assertRaisesMessage(TypeError, "Cannot call defer() after .values() or .values_list()"): + Valid.objects.values().defer() + + +class Queries4Tests(TestCase): + @classmethod + def setUpTestData(cls): + generic = NamedCategory.objects.create(name="Generic") + cls.t1 = Tag.objects.create(name='t1', category=generic) + + n1 = Note.objects.create(note='n1', misc='foo') + n2 = Note.objects.create(note='n2', misc='bar') + + e1 = ExtraInfo.objects.create(info='e1', note=n1) + e2 = ExtraInfo.objects.create(info='e2', note=n2) + + cls.a1 = Author.objects.create(name='a1', num=1001, extra=e1) + cls.a3 = Author.objects.create(name='a3', num=3003, extra=e2) + + cls.r1 = Report.objects.create(name='r1', creator=cls.a1) + cls.r2 = Report.objects.create(name='r2', creator=cls.a3) + cls.r3 = Report.objects.create(name='r3') + + Item.objects.create(name='i1', created=datetime.datetime.now(), note=n1, creator=cls.a1) + Item.objects.create(name='i2', created=datetime.datetime.now(), note=n1, creator=cls.a3) + + def test_ticket24525(self): + tag = Tag.objects.create() + anth100 = tag.note_set.create(note='ANTH', misc='100') + math101 = tag.note_set.create(note='MATH', misc='101') + s1 = tag.annotation_set.create(name='1') + s2 = tag.annotation_set.create(name='2') + s1.notes.set([math101, anth100]) + s2.notes.set([math101]) + result = math101.annotation_set.all() & tag.annotation_set.exclude(notes__in=[anth100]) + self.assertEqual(list(result), [s2]) + + def test_ticket11811(self): + unsaved_category = NamedCategory(name="Other") + msg = 'Unsaved model instance cannot be used in an ORM query.' 
+ with self.assertRaisesMessage(ValueError, msg): + Tag.objects.filter(pk=self.t1.pk).update(category=unsaved_category) + + def test_ticket14876(self): + # Note: when combining the query we need to have information available + # about the join type of the trimmed "creator__isnull" join. If we + # don't have that information, then the join is created as INNER JOIN + # and results will be incorrect. + q1 = Report.objects.filter(Q(creator__isnull=True) | Q(creator__extra__info='e1')) + q2 = Report.objects.filter(Q(creator__isnull=True)) | Report.objects.filter(Q(creator__extra__info='e1')) + self.assertQuerysetEqual(q1, ["", ""], ordered=False) + self.assertEqual(str(q1.query), str(q2.query)) + + q1 = Report.objects.filter(Q(creator__extra__info='e1') | Q(creator__isnull=True)) + q2 = Report.objects.filter(Q(creator__extra__info='e1')) | Report.objects.filter(Q(creator__isnull=True)) + self.assertQuerysetEqual(q1, ["", ""], ordered=False) + self.assertEqual(str(q1.query), str(q2.query)) + + q1 = Item.objects.filter(Q(creator=self.a1) | Q(creator__report__name='r1')).order_by() + q2 = ( + Item.objects + .filter(Q(creator=self.a1)).order_by() | Item.objects.filter(Q(creator__report__name='r1')) + .order_by() + ) + self.assertQuerysetEqual(q1, [""]) + self.assertEqual(str(q1.query), str(q2.query)) + + q1 = Item.objects.filter(Q(creator__report__name='e1') | Q(creator=self.a1)).order_by() + q2 = ( + Item.objects.filter(Q(creator__report__name='e1')).order_by() | + Item.objects.filter(Q(creator=self.a1)).order_by() + ) + self.assertQuerysetEqual(q1, [""]) + self.assertEqual(str(q1.query), str(q2.query)) + + def test_combine_join_reuse(self): + # Joins having identical connections are correctly recreated in the + # rhs query, in case the query is ORed together (#18748). 
+ Report.objects.create(name='r4', creator=self.a1) + q1 = Author.objects.filter(report__name='r5') + q2 = Author.objects.filter(report__name='r4').filter(report__name='r1') + combined = q1 | q2 + self.assertEqual(str(combined.query).count('JOIN'), 2) + self.assertEqual(len(combined), 1) + self.assertEqual(combined[0].name, 'a1') + + def test_join_reuse_order(self): + # Join aliases are reused in order. This shouldn't raise AssertionError + # because change_map contains a circular reference (#26522). + s1 = School.objects.create() + s2 = School.objects.create() + s3 = School.objects.create() + t1 = Teacher.objects.create() + otherteachers = Teacher.objects.exclude(pk=t1.pk).exclude(friends=t1) + qs1 = otherteachers.filter(schools=s1).filter(schools=s2) + qs2 = otherteachers.filter(schools=s1).filter(schools=s3) + self.assertQuerysetEqual(qs1 | qs2, []) + + def test_ticket7095(self): + # Updates that are filtered on the model being updated are somewhat + # tricky in MySQL. + ManagedModel.objects.create(data='mm1', tag=self.t1, public=True) + self.assertEqual(ManagedModel.objects.update(data='mm'), 1) + + # A values() or values_list() query across joined models must use outer + # joins appropriately. + # Note: In Oracle, we expect a null CharField to return '' instead of + # None. + if connection.features.interprets_empty_strings_as_nulls: + expected_null_charfield_repr = '' + else: + expected_null_charfield_repr = None + self.assertSequenceEqual( + Report.objects.values_list("creator__extra__info", flat=True).order_by("name"), + ['e1', 'e2', expected_null_charfield_repr], + ) + + # Similarly for select_related(), joins beyond an initial nullable join + # must use outer joins so that all results are included. 
+ self.assertQuerysetEqual( + Report.objects.select_related("creator", "creator__extra").order_by("name"), + ['', '', ''] + ) + + # When there are multiple paths to a table from another table, we have + # to be careful not to accidentally reuse an inappropriate join when + # using select_related(). We used to return the parent's Detail record + # here by mistake. + + d1 = Detail.objects.create(data="d1") + d2 = Detail.objects.create(data="d2") + m1 = Member.objects.create(name="m1", details=d1) + m2 = Member.objects.create(name="m2", details=d2) + Child.objects.create(person=m2, parent=m1) + obj = m1.children.select_related("person__details")[0] + self.assertEqual(obj.person.details.data, 'd2') + + def test_order_by_resetting(self): + # Calling order_by() with no parameters removes any existing ordering on the + # model. But it should still be possible to add new ordering after that. + qs = Author.objects.order_by().order_by('name') + self.assertIn('ORDER BY', qs.query.get_compiler(qs.db).as_sql()[0]) + + def test_order_by_reverse_fk(self): + # It is possible to order by reverse of foreign key, although that can lead + # to duplicate results. + c1 = SimpleCategory.objects.create(name="category1") + c2 = SimpleCategory.objects.create(name="category2") + CategoryItem.objects.create(category=c1) + CategoryItem.objects.create(category=c2) + CategoryItem.objects.create(category=c1) + self.assertSequenceEqual(SimpleCategory.objects.order_by('categoryitem', 'pk'), [c1, c2, c1]) + + def test_ticket10181(self): + # Avoid raising an EmptyResultSet if an inner query is probably + # empty (and hence, not executed). 
+ self.assertQuerysetEqual( + Tag.objects.filter(id__in=Tag.objects.filter(id__in=[])), + [] + ) + + def test_ticket15316_filter_false(self): + c1 = SimpleCategory.objects.create(name="category1") + c2 = SpecialCategory.objects.create(name="named category1", special_name="special1") + c3 = SpecialCategory.objects.create(name="named category2", special_name="special2") + + CategoryItem.objects.create(category=c1) + ci2 = CategoryItem.objects.create(category=c2) + ci3 = CategoryItem.objects.create(category=c3) + + qs = CategoryItem.objects.filter(category__specialcategory__isnull=False) + self.assertEqual(qs.count(), 2) + self.assertSequenceEqual(qs, [ci2, ci3]) + + def test_ticket15316_exclude_false(self): + c1 = SimpleCategory.objects.create(name="category1") + c2 = SpecialCategory.objects.create(name="named category1", special_name="special1") + c3 = SpecialCategory.objects.create(name="named category2", special_name="special2") + + ci1 = CategoryItem.objects.create(category=c1) + CategoryItem.objects.create(category=c2) + CategoryItem.objects.create(category=c3) + + qs = CategoryItem.objects.exclude(category__specialcategory__isnull=False) + self.assertEqual(qs.count(), 1) + self.assertSequenceEqual(qs, [ci1]) + + def test_ticket15316_filter_true(self): + c1 = SimpleCategory.objects.create(name="category1") + c2 = SpecialCategory.objects.create(name="named category1", special_name="special1") + c3 = SpecialCategory.objects.create(name="named category2", special_name="special2") + + ci1 = CategoryItem.objects.create(category=c1) + CategoryItem.objects.create(category=c2) + CategoryItem.objects.create(category=c3) + + qs = CategoryItem.objects.filter(category__specialcategory__isnull=True) + self.assertEqual(qs.count(), 1) + self.assertSequenceEqual(qs, [ci1]) + + def test_ticket15316_exclude_true(self): + c1 = SimpleCategory.objects.create(name="category1") + c2 = SpecialCategory.objects.create(name="named category1", special_name="special1") + c3 = 
SpecialCategory.objects.create(name="named category2", special_name="special2") + + CategoryItem.objects.create(category=c1) + ci2 = CategoryItem.objects.create(category=c2) + ci3 = CategoryItem.objects.create(category=c3) + + qs = CategoryItem.objects.exclude(category__specialcategory__isnull=True) + self.assertEqual(qs.count(), 2) + self.assertSequenceEqual(qs, [ci2, ci3]) + + def test_ticket15316_one2one_filter_false(self): + c = SimpleCategory.objects.create(name="cat") + c0 = SimpleCategory.objects.create(name="cat0") + c1 = SimpleCategory.objects.create(name="category1") + + OneToOneCategory.objects.create(category=c1, new_name="new1") + OneToOneCategory.objects.create(category=c0, new_name="new2") + + CategoryItem.objects.create(category=c) + ci2 = CategoryItem.objects.create(category=c0) + ci3 = CategoryItem.objects.create(category=c1) + + qs = CategoryItem.objects.filter(category__onetoonecategory__isnull=False).order_by('pk') + self.assertEqual(qs.count(), 2) + self.assertSequenceEqual(qs, [ci2, ci3]) + + def test_ticket15316_one2one_exclude_false(self): + c = SimpleCategory.objects.create(name="cat") + c0 = SimpleCategory.objects.create(name="cat0") + c1 = SimpleCategory.objects.create(name="category1") + + OneToOneCategory.objects.create(category=c1, new_name="new1") + OneToOneCategory.objects.create(category=c0, new_name="new2") + + ci1 = CategoryItem.objects.create(category=c) + CategoryItem.objects.create(category=c0) + CategoryItem.objects.create(category=c1) + + qs = CategoryItem.objects.exclude(category__onetoonecategory__isnull=False) + self.assertEqual(qs.count(), 1) + self.assertSequenceEqual(qs, [ci1]) + + def test_ticket15316_one2one_filter_true(self): + c = SimpleCategory.objects.create(name="cat") + c0 = SimpleCategory.objects.create(name="cat0") + c1 = SimpleCategory.objects.create(name="category1") + + OneToOneCategory.objects.create(category=c1, new_name="new1") + OneToOneCategory.objects.create(category=c0, new_name="new2") + + ci1 = 
CategoryItem.objects.create(category=c) + CategoryItem.objects.create(category=c0) + CategoryItem.objects.create(category=c1) + + qs = CategoryItem.objects.filter(category__onetoonecategory__isnull=True) + self.assertEqual(qs.count(), 1) + self.assertSequenceEqual(qs, [ci1]) + + def test_ticket15316_one2one_exclude_true(self): + c = SimpleCategory.objects.create(name="cat") + c0 = SimpleCategory.objects.create(name="cat0") + c1 = SimpleCategory.objects.create(name="category1") + + OneToOneCategory.objects.create(category=c1, new_name="new1") + OneToOneCategory.objects.create(category=c0, new_name="new2") + + CategoryItem.objects.create(category=c) + ci2 = CategoryItem.objects.create(category=c0) + ci3 = CategoryItem.objects.create(category=c1) + + qs = CategoryItem.objects.exclude(category__onetoonecategory__isnull=True).order_by('pk') + self.assertEqual(qs.count(), 2) + self.assertSequenceEqual(qs, [ci2, ci3]) + + +class Queries5Tests(TestCase): + @classmethod + def setUpTestData(cls): + # Ordering by 'rank' gives us rank2, rank1, rank3. Ordering by the + # Meta.ordering will be rank3, rank2, rank1. + n1 = Note.objects.create(note='n1', misc='foo', id=1) + n2 = Note.objects.create(note='n2', misc='bar', id=2) + e1 = ExtraInfo.objects.create(info='e1', note=n1) + e2 = ExtraInfo.objects.create(info='e2', note=n2) + a1 = Author.objects.create(name='a1', num=1001, extra=e1) + a2 = Author.objects.create(name='a2', num=2002, extra=e1) + a3 = Author.objects.create(name='a3', num=3003, extra=e2) + cls.rank1 = Ranking.objects.create(rank=2, author=a2) + Ranking.objects.create(rank=1, author=a3) + Ranking.objects.create(rank=3, author=a1) + + def test_ordering(self): + # Cross model ordering is possible in Meta, too. 
+ self.assertQuerysetEqual( + Ranking.objects.all(), + ['', '', ''] + ) + self.assertQuerysetEqual( + Ranking.objects.all().order_by('rank'), + ['', '', ''] + ) + + # Ordering of extra() pieces is possible, too and you can mix extra + # fields and model fields in the ordering. + self.assertQuerysetEqual( + Ranking.objects.extra(tables=['django_site'], order_by=['-django_site.id', 'rank']), + ['', '', ''] + ) + + sql = 'case when %s > 2 then 1 else 0 end' % connection.ops.quote_name('rank') + qs = Ranking.objects.extra(select={'good': sql}) + self.assertEqual( + [o.good for o in qs.extra(order_by=('-good',))], + [True, False, False] + ) + self.assertQuerysetEqual( + qs.extra(order_by=('-good', 'id')), + ['', '', ''] + ) + + # Despite having some extra aliases in the query, we can still omit + # them in a values() query. + dicts = qs.values('id', 'rank').order_by('id') + self.assertEqual( + [d['rank'] for d in dicts], + [2, 1, 3] + ) + + def test_ticket7256(self): + # An empty values() call includes all aliases, including those from an + # extra() + sql = 'case when %s > 2 then 1 else 0 end' % connection.ops.quote_name('rank') + qs = Ranking.objects.extra(select={'good': sql}) + dicts = qs.values().order_by('id') + for d in dicts: + del d['id'] + del d['author_id'] + self.assertEqual( + [sorted(d.items()) for d in dicts], + [[('good', 0), ('rank', 2)], [('good', 0), ('rank', 1)], [('good', 1), ('rank', 3)]] + ) + + def test_ticket7045(self): + # Extra tables used to crash SQL construction on the second use. + qs = Ranking.objects.extra(tables=['django_site']) + qs.query.get_compiler(qs.db).as_sql() + # test passes if this doesn't raise an exception. + qs.query.get_compiler(qs.db).as_sql() + + def test_ticket9848(self): + # Make sure that updates which only filter on sub-tables don't + # inadvertently update the wrong records (bug #9848). 
+ author_start = Author.objects.get(name='a1') + ranking_start = Ranking.objects.get(author__name='a1') + + # Make sure that the IDs from different tables don't happen to match. + self.assertQuerysetEqual( + Ranking.objects.filter(author__name='a1'), + [''] + ) + self.assertEqual( + Ranking.objects.filter(author__name='a1').update(rank=4636), + 1 + ) + + r = Ranking.objects.get(author__name='a1') + self.assertEqual(r.id, ranking_start.id) + self.assertEqual(r.author.id, author_start.id) + self.assertEqual(r.rank, 4636) + r.rank = 3 + r.save() + self.assertQuerysetEqual( + Ranking.objects.all(), + ['', '', ''] + ) + + def test_ticket5261(self): + # Test different empty excludes. + self.assertQuerysetEqual( + Note.objects.exclude(Q()), + ['', ''] + ) + self.assertQuerysetEqual( + Note.objects.filter(~Q()), + ['', ''] + ) + self.assertQuerysetEqual( + Note.objects.filter(~Q() | ~Q()), + ['', ''] + ) + self.assertQuerysetEqual( + Note.objects.exclude(~Q() & ~Q()), + ['', ''] + ) + + def test_extra_select_literal_percent_s(self): + # Allow %%s to escape select clauses + self.assertEqual( + Note.objects.extra(select={'foo': "'%%s'"})[0].foo, + '%s' + ) + self.assertEqual( + Note.objects.extra(select={'foo': "'%%s bar %%s'"})[0].foo, + '%s bar %s' + ) + self.assertEqual( + Note.objects.extra(select={'foo': "'bar %%s'"})[0].foo, + 'bar %s' + ) + + +class SelectRelatedTests(TestCase): + def test_tickets_3045_3288(self): + # Once upon a time, select_related() with circular relations would loop + # infinitely if you forgot to specify "depth". Now we set an arbitrary + # default upper bound. + self.assertQuerysetEqual(X.objects.all(), []) + self.assertQuerysetEqual(X.objects.select_related(), []) + + +class SubclassFKTests(TestCase): + def test_ticket7778(self): + # Model subclasses could not be deleted if a nullable foreign key + # relates to a model that relates back. 
+ + num_celebs = Celebrity.objects.count() + tvc = TvChef.objects.create(name="Huey") + self.assertEqual(Celebrity.objects.count(), num_celebs + 1) + Fan.objects.create(fan_of=tvc) + Fan.objects.create(fan_of=tvc) + tvc.delete() + + # The parent object should have been deleted as well. + self.assertEqual(Celebrity.objects.count(), num_celebs) + + +class CustomPkTests(TestCase): + def test_ticket7371(self): + self.assertQuerysetEqual(Related.objects.order_by('custom'), []) + + +class NullableRelOrderingTests(TestCase): + def test_ticket10028(self): + # Ordering by model related to nullable relations(!) should use outer + # joins, so that all results are included. + Plaything.objects.create(name="p1") + self.assertQuerysetEqual( + Plaything.objects.all(), + [''] + ) + + def test_join_already_in_query(self): + # Ordering by model related to nullable relations should not change + # the join type of already existing joins. + Plaything.objects.create(name="p1") + s = SingleObject.objects.create(name='s') + r = RelatedObject.objects.create(single=s, f=1) + Plaything.objects.create(name="p2", others=r) + qs = Plaything.objects.all().filter(others__isnull=False).order_by('pk') + self.assertNotIn('JOIN', str(qs.query)) + qs = Plaything.objects.all().filter(others__f__isnull=False).order_by('pk') + self.assertIn('INNER', str(qs.query)) + qs = qs.order_by('others__single__name') + # The ordering by others__single__pk will add one new join (to single) + # and that join must be LEFT join. The already existing join to related + # objects must be kept INNER. So, we have both an INNER and a LEFT join + # in the query. 
+ self.assertEqual(str(qs.query).count('LEFT'), 1) + self.assertEqual(str(qs.query).count('INNER'), 1) + self.assertQuerysetEqual( + qs, + [''] + ) + + +class DisjunctiveFilterTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.n1 = Note.objects.create(note='n1', misc='foo', id=1) + ExtraInfo.objects.create(info='e1', note=cls.n1) + + def test_ticket7872(self): + # Another variation on the disjunctive filtering theme. + + # For the purposes of this regression test, it's important that there is no + # Join object related to the LeafA we create. + LeafA.objects.create(data='first') + self.assertQuerysetEqual(LeafA.objects.all(), ['']) + self.assertQuerysetEqual( + LeafA.objects.filter(Q(data='first') | Q(join__b__data='second')), + [''] + ) + + def test_ticket8283(self): + # Checking that applying filters after a disjunction works correctly. + self.assertQuerysetEqual( + (ExtraInfo.objects.filter(note=self.n1) | ExtraInfo.objects.filter(info='e2')).filter(note=self.n1), + [''] + ) + self.assertQuerysetEqual( + (ExtraInfo.objects.filter(info='e2') | ExtraInfo.objects.filter(note=self.n1)).filter(note=self.n1), + [''] + ) + + +class Queries6Tests(TestCase): + @classmethod + def setUpTestData(cls): + generic = NamedCategory.objects.create(name="Generic") + t1 = Tag.objects.create(name='t1', category=generic) + Tag.objects.create(name='t2', parent=t1, category=generic) + t3 = Tag.objects.create(name='t3', parent=t1) + t4 = Tag.objects.create(name='t4', parent=t3) + Tag.objects.create(name='t5', parent=t3) + n1 = Note.objects.create(note='n1', misc='foo', id=1) + ann1 = Annotation.objects.create(name='a1', tag=t1) + ann1.notes.add(n1) + Annotation.objects.create(name='a2', tag=t4) + + def test_parallel_iterators(self): + # Parallel iterators work. 
+ qs = Tag.objects.all() + i1, i2 = iter(qs), iter(qs) + self.assertEqual(repr(next(i1)), '') + self.assertEqual(repr(next(i1)), '') + self.assertEqual(repr(next(i2)), '') + self.assertEqual(repr(next(i2)), '') + self.assertEqual(repr(next(i2)), '') + self.assertEqual(repr(next(i1)), '') + + qs = X.objects.all() + self.assertFalse(qs) + self.assertFalse(qs) + + def test_nested_queries_sql(self): + # Nested queries should not evaluate the inner query as part of constructing the + # SQL (so we should see a nested query here, indicated by two "SELECT" calls). + qs = Annotation.objects.filter(notes__in=Note.objects.filter(note="xyzzy")) + self.assertEqual( + qs.query.get_compiler(qs.db).as_sql()[0].count('SELECT'), + 2 + ) + + def test_tickets_8921_9188(self): + # Incorrect SQL was being generated for certain types of exclude() + # queries that crossed multi-valued relations (#8921, #9188 and some + # preemptively discovered cases). + + self.assertQuerysetEqual( + PointerA.objects.filter(connection__pointerb__id=1), + [] + ) + self.assertQuerysetEqual( + PointerA.objects.exclude(connection__pointerb__id=1), + [] + ) + + self.assertQuerysetEqual( + Tag.objects.exclude(children=None), + ['', ''] + ) + + # This example is tricky because the parent could be NULL, so only checking + # parents with annotations omits some results (tag t1, in this case). + self.assertQuerysetEqual( + Tag.objects.exclude(parent__annotation__name="a1"), + ['', '', ''] + ) + + # The annotation->tag link is single values and tag->children links is + # multi-valued. So we have to split the exclude filter in the middle + # and then optimize the inner query without losing results. + self.assertQuerysetEqual( + Annotation.objects.exclude(tag__children__name="t2"), + [''] + ) + + # Nested queries are possible (although should be used with care, since + # they have performance problems on backends like MySQL. 
+ self.assertQuerysetEqual( + Annotation.objects.filter(notes__in=Note.objects.filter(note="n1")), + [''] + ) + + def test_ticket3739(self): + # The all() method on querysets returns a copy of the queryset. + q1 = Tag.objects.order_by('name') + self.assertIsNot(q1, q1.all()) + + def test_ticket_11320(self): + qs = Tag.objects.exclude(category=None).exclude(category__name='foo') + self.assertEqual(str(qs.query).count(' INNER JOIN '), 1) + + def test_distinct_ordered_sliced_subquery_aggregation(self): + self.assertEqual(Tag.objects.distinct().order_by('category__name')[:3].count(), 3) + + +class RawQueriesTests(TestCase): + def setUp(self): + Note.objects.create(note='n1', misc='foo', id=1) + + def test_ticket14729(self): + # Test representation of raw query with one or few parameters passed as list + query = "SELECT * FROM queries_note WHERE note = %s" + params = ['n1'] + qs = Note.objects.raw(query, params=params) + self.assertEqual(repr(qs), "") + + query = "SELECT * FROM queries_note WHERE note = %s and misc = %s" + params = ['n1', 'foo'] + qs = Note.objects.raw(query, params=params) + self.assertEqual(repr(qs), "") + + +class GeneratorExpressionTests(TestCase): + def test_ticket10432(self): + # Using an empty generator expression as the rvalue for an "__in" + # lookup is legal. 
+ self.assertQuerysetEqual( + Note.objects.filter(pk__in=(x for x in ())), + [] + ) + + +class ComparisonTests(TestCase): + def setUp(self): + self.n1 = Note.objects.create(note='n1', misc='foo', id=1) + e1 = ExtraInfo.objects.create(info='e1', note=self.n1) + self.a2 = Author.objects.create(name='a2', num=2002, extra=e1) + + def test_ticket8597(self): + # Regression tests for case-insensitive comparisons + Item.objects.create(name="a_b", created=datetime.datetime.now(), creator=self.a2, note=self.n1) + Item.objects.create(name="x%y", created=datetime.datetime.now(), creator=self.a2, note=self.n1) + self.assertQuerysetEqual( + Item.objects.filter(name__iexact="A_b"), + [''] + ) + self.assertQuerysetEqual( + Item.objects.filter(name__iexact="x%Y"), + [''] + ) + self.assertQuerysetEqual( + Item.objects.filter(name__istartswith="A_b"), + [''] + ) + self.assertQuerysetEqual( + Item.objects.filter(name__iendswith="A_b"), + [''] + ) + + +class ExistsSql(TestCase): + def test_exists(self): + with CaptureQueriesContext(connection) as captured_queries: + self.assertFalse(Tag.objects.exists()) + # Ok - so the exist query worked - but did it include too many columns? 
+ self.assertEqual(len(captured_queries), 1) + qstr = captured_queries[0]['sql'] + id, name = connection.ops.quote_name('id'), connection.ops.quote_name('name') + self.assertNotIn(id, qstr) + self.assertNotIn(name, qstr) + + def test_ticket_18414(self): + Article.objects.create(name='one', created=datetime.datetime.now()) + Article.objects.create(name='one', created=datetime.datetime.now()) + Article.objects.create(name='two', created=datetime.datetime.now()) + self.assertTrue(Article.objects.exists()) + self.assertTrue(Article.objects.distinct().exists()) + self.assertTrue(Article.objects.distinct()[1:3].exists()) + self.assertFalse(Article.objects.distinct()[1:1].exists()) + + @skipUnlessDBFeature('can_distinct_on_fields') + def test_ticket_18414_distinct_on(self): + Article.objects.create(name='one', created=datetime.datetime.now()) + Article.objects.create(name='one', created=datetime.datetime.now()) + Article.objects.create(name='two', created=datetime.datetime.now()) + self.assertTrue(Article.objects.distinct('name').exists()) + self.assertTrue(Article.objects.distinct('name')[1:2].exists()) + self.assertFalse(Article.objects.distinct('name')[2:3].exists()) + + +class QuerysetOrderedTests(unittest.TestCase): + """ + Tests for the Queryset.ordered attribute. 
+ """ + + def test_no_default_or_explicit_ordering(self): + self.assertIs(Annotation.objects.all().ordered, False) + + def test_cleared_default_ordering(self): + self.assertIs(Tag.objects.all().ordered, True) + self.assertIs(Tag.objects.all().order_by().ordered, False) + + def test_explicit_ordering(self): + self.assertIs(Annotation.objects.all().order_by('id').ordered, True) + + def test_order_by_extra(self): + self.assertIs(Annotation.objects.all().extra(order_by=['id']).ordered, True) + + def test_annotated_ordering(self): + qs = Annotation.objects.annotate(num_notes=Count('notes')) + self.assertIs(qs.ordered, False) + self.assertIs(qs.order_by('num_notes').ordered, True) + + +@skipUnlessDBFeature('allow_sliced_subqueries_with_in') +class SubqueryTests(TestCase): + @classmethod + def setUpTestData(cls): + NamedCategory.objects.create(id=1, name='first') + NamedCategory.objects.create(id=2, name='second') + NamedCategory.objects.create(id=3, name='third') + NamedCategory.objects.create(id=4, name='fourth') + + def test_ordered_subselect(self): + "Subselects honor any manual ordering" + query = DumbCategory.objects.filter(id__in=DumbCategory.objects.order_by('-id')[0:2]) + self.assertEqual(set(query.values_list('id', flat=True)), {3, 4}) + + query = DumbCategory.objects.filter(id__in=DumbCategory.objects.order_by('-id')[:2]) + self.assertEqual(set(query.values_list('id', flat=True)), {3, 4}) + + query = DumbCategory.objects.filter(id__in=DumbCategory.objects.order_by('-id')[1:2]) + self.assertEqual(set(query.values_list('id', flat=True)), {3}) + + query = DumbCategory.objects.filter(id__in=DumbCategory.objects.order_by('-id')[2:]) + self.assertEqual(set(query.values_list('id', flat=True)), {1, 2}) + + def test_slice_subquery_and_query(self): + """ + Slice a query that has a sliced subquery + """ + query = DumbCategory.objects.filter(id__in=DumbCategory.objects.order_by('-id')[0:2])[0:2] + self.assertEqual({x.id for x in query}, {3, 4}) + + query = 
DumbCategory.objects.filter(id__in=DumbCategory.objects.order_by('-id')[1:3])[1:3] + self.assertEqual({x.id for x in query}, {3}) + + query = DumbCategory.objects.filter(id__in=DumbCategory.objects.order_by('-id')[2:])[1:] + self.assertEqual({x.id for x in query}, {2}) + + def test_related_sliced_subquery(self): + """ + Related objects constraints can safely contain sliced subqueries. + refs #22434 + """ + generic = NamedCategory.objects.create(id=5, name="Generic") + t1 = Tag.objects.create(name='t1', category=generic) + t2 = Tag.objects.create(name='t2', category=generic) + ManagedModel.objects.create(data='mm1', tag=t1, public=True) + mm2 = ManagedModel.objects.create(data='mm2', tag=t2, public=True) + + query = ManagedModel.normal_manager.filter( + tag__in=Tag.objects.order_by('-id')[:1] + ) + self.assertEqual({x.id for x in query}, {mm2.id}) + + def test_sliced_delete(self): + "Delete queries can safely contain sliced subqueries" + DumbCategory.objects.filter(id__in=DumbCategory.objects.order_by('-id')[0:1]).delete() + self.assertEqual(set(DumbCategory.objects.values_list('id', flat=True)), {1, 2, 3}) + + DumbCategory.objects.filter(id__in=DumbCategory.objects.order_by('-id')[1:2]).delete() + self.assertEqual(set(DumbCategory.objects.values_list('id', flat=True)), {1, 3}) + + DumbCategory.objects.filter(id__in=DumbCategory.objects.order_by('-id')[1:]).delete() + self.assertEqual(set(DumbCategory.objects.values_list('id', flat=True)), {3}) + + def test_distinct_ordered_sliced_subquery(self): + # Implicit values('id'). + self.assertSequenceEqual( + NamedCategory.objects.filter( + id__in=NamedCategory.objects.distinct().order_by('name')[0:2], + ).order_by('name').values_list('name', flat=True), ['first', 'fourth'] + ) + # Explicit values('id'). 
+ self.assertSequenceEqual( + NamedCategory.objects.filter( + id__in=NamedCategory.objects.distinct().order_by('-name').values('id')[0:2], + ).order_by('name').values_list('name', flat=True), ['second', 'third'] + ) + # Annotated value. + self.assertSequenceEqual( + DumbCategory.objects.filter( + id__in=DumbCategory.objects.annotate( + double_id=F('id') * 2 + ).order_by('id').distinct().values('double_id')[0:2], + ).order_by('id').values_list('id', flat=True), [2, 4] + ) + + +class CloneTests(TestCase): + + def test_evaluated_queryset_as_argument(self): + "#13227 -- If a queryset is already evaluated, it can still be used as a query arg" + n = Note(note='Test1', misc='misc') + n.save() + e = ExtraInfo(info='good', note=n) + e.save() + + n_list = Note.objects.all() + # Evaluate the Note queryset, populating the query cache + list(n_list) + # Use the note queryset in a query, and evaluate + # that query in a way that involves cloning. + self.assertEqual(ExtraInfo.objects.filter(note__in=n_list)[0].info, 'good') + + def test_no_model_options_cloning(self): + """ + Cloning a queryset does not get out of hand. While complete + testing is impossible, this is a sanity check against invalid use of + deepcopy. refs #16759. + """ + opts_class = type(Note._meta) + note_deepcopy = getattr(opts_class, "__deepcopy__", None) + opts_class.__deepcopy__ = lambda obj, memo: self.fail("Model options shouldn't be cloned.") + try: + Note.objects.filter(pk__lte=F('pk') + 1).all() + finally: + if note_deepcopy is None: + delattr(opts_class, "__deepcopy__") + else: + opts_class.__deepcopy__ = note_deepcopy + + def test_no_fields_cloning(self): + """ + Cloning a queryset does not get out of hand. While complete + testing is impossible, this is a sanity check against invalid use of + deepcopy. refs #16759. 
+ """ + opts_class = type(Note._meta.get_field("misc")) + note_deepcopy = getattr(opts_class, "__deepcopy__", None) + opts_class.__deepcopy__ = lambda obj, memo: self.fail("Model fields shouldn't be cloned") + try: + Note.objects.filter(note=F('misc')).all() + finally: + if note_deepcopy is None: + delattr(opts_class, "__deepcopy__") + else: + opts_class.__deepcopy__ = note_deepcopy + + +class EmptyQuerySetTests(TestCase): + def test_emptyqueryset_values(self): + # #14366 -- Calling .values() on an empty QuerySet and then cloning + # that should not cause an error + self.assertQuerysetEqual( + Number.objects.none().values('num').order_by('num'), [] + ) + + def test_values_subquery(self): + self.assertQuerysetEqual( + Number.objects.filter(pk__in=Number.objects.none().values("pk")), + [] + ) + self.assertQuerysetEqual( + Number.objects.filter(pk__in=Number.objects.none().values_list("pk")), + [] + ) + + def test_ticket_19151(self): + # #19151 -- Calling .values() or .values_list() on an empty QuerySet + # should return an empty QuerySet and not cause an error. 
+ q = Author.objects.none() + self.assertQuerysetEqual(q.values(), []) + self.assertQuerysetEqual(q.values_list(), []) + + +class ValuesQuerysetTests(TestCase): + @classmethod + def setUpTestData(cls): + Number.objects.create(num=72) + + def test_flat_values_list(self): + qs = Number.objects.values_list("num") + qs = qs.values_list("num", flat=True) + self.assertSequenceEqual(qs, [72]) + + def test_extra_values(self): + # testing for ticket 14930 issues + qs = Number.objects.extra(select=OrderedDict([('value_plus_x', 'num+%s'), + ('value_minus_x', 'num-%s')]), + select_params=(1, 2)) + qs = qs.order_by('value_minus_x') + qs = qs.values('num') + self.assertSequenceEqual(qs, [{'num': 72}]) + + def test_extra_values_order_twice(self): + # testing for ticket 14930 issues + qs = Number.objects.extra(select={'value_plus_one': 'num+1', 'value_minus_one': 'num-1'}) + qs = qs.order_by('value_minus_one').order_by('value_plus_one') + qs = qs.values('num') + self.assertSequenceEqual(qs, [{'num': 72}]) + + def test_extra_values_order_multiple(self): + # Postgres doesn't allow constants in order by, so check for that. 
+ qs = Number.objects.extra(select={ + 'value_plus_one': 'num+1', + 'value_minus_one': 'num-1', + 'constant_value': '1' + }) + qs = qs.order_by('value_plus_one', 'value_minus_one', 'constant_value') + qs = qs.values('num') + self.assertSequenceEqual(qs, [{'num': 72}]) + + def test_extra_values_order_in_extra(self): + # testing for ticket 14930 issues + qs = Number.objects.extra( + select={'value_plus_one': 'num+1', 'value_minus_one': 'num-1'}, + order_by=['value_minus_one'], + ) + qs = qs.values('num') + + def test_extra_select_params_values_order_in_extra(self): + # testing for 23259 issue + qs = Number.objects.extra( + select={'value_plus_x': 'num+%s'}, + select_params=[1], + order_by=['value_plus_x'], + ) + qs = qs.filter(num=72) + qs = qs.values('num') + self.assertSequenceEqual(qs, [{'num': 72}]) + + def test_extra_multiple_select_params_values_order_by(self): + # testing for 23259 issue + qs = Number.objects.extra(select=OrderedDict([('value_plus_x', 'num+%s'), + ('value_minus_x', 'num-%s')]), + select_params=(72, 72)) + qs = qs.order_by('value_minus_x') + qs = qs.filter(num=1) + qs = qs.values('num') + self.assertSequenceEqual(qs, []) + + def test_extra_values_list(self): + # testing for ticket 14930 issues + qs = Number.objects.extra(select={'value_plus_one': 'num+1'}) + qs = qs.order_by('value_plus_one') + qs = qs.values_list('num') + self.assertSequenceEqual(qs, [(72,)]) + + def test_flat_extra_values_list(self): + # testing for ticket 14930 issues + qs = Number.objects.extra(select={'value_plus_one': 'num+1'}) + qs = qs.order_by('value_plus_one') + qs = qs.values_list('num', flat=True) + self.assertSequenceEqual(qs, [72]) + + def test_field_error_values_list(self): + # see #23443 + msg = "Cannot resolve keyword %r into field. Join on 'name' not permitted." % 'foo' + with self.assertRaisesMessage(FieldError, msg): + Tag.objects.values_list('name__foo') + + def test_named_values_list_flat(self): + msg = "'flat' and 'named' can't be used together." 
+ with self.assertRaisesMessage(TypeError, msg): + Number.objects.values_list('num', flat=True, named=True) + + def test_named_values_list_bad_field_name(self): + msg = "Type names and field names must be valid identifiers: '1'" + with self.assertRaisesMessage(ValueError, msg): + Number.objects.extra(select={'1': 'num+1'}).values_list('1', named=True).first() + + def test_named_values_list_with_fields(self): + qs = Number.objects.extra(select={'num2': 'num+1'}).annotate(Count('id')) + values = qs.values_list('num', 'num2', named=True).first() + self.assertEqual(type(values).__name__, 'Row') + self.assertEqual(values._fields, ('num', 'num2')) + self.assertEqual(values.num, 72) + self.assertEqual(values.num2, 73) + + def test_named_values_list_without_fields(self): + qs = Number.objects.extra(select={'num2': 'num+1'}).annotate(Count('id')) + values = qs.values_list(named=True).first() + self.assertEqual(type(values).__name__, 'Row') + self.assertEqual(values._fields, ('num2', 'id', 'num', 'id__count')) + self.assertEqual(values.num, 72) + self.assertEqual(values.num2, 73) + self.assertEqual(values.id__count, 1) + + def test_named_values_list_expression_with_default_alias(self): + expr = Count('id') + values = Number.objects.annotate(id__count1=expr).values_list(expr, 'id__count1', named=True).first() + self.assertEqual(values._fields, ('id__count2', 'id__count1')) + + def test_named_values_list_expression(self): + expr = F('num') + 1 + qs = Number.objects.annotate(combinedexpression1=expr).values_list(expr, 'combinedexpression1', named=True) + values = qs.first() + self.assertEqual(values._fields, ('combinedexpression2', 'combinedexpression1')) + + +class QuerySetSupportsPythonIdioms(TestCase): + + @classmethod + def setUpTestData(cls): + some_date = datetime.datetime(2014, 5, 16, 12, 1) + for i in range(1, 8): + Article.objects.create( + name="Article {}".format(i), created=some_date) + + def get_ordered_articles(self): + return 
Article.objects.all().order_by('name') + + def test_can_get_items_using_index_and_slice_notation(self): + self.assertEqual(self.get_ordered_articles()[0].name, 'Article 1') + self.assertQuerysetEqual( + self.get_ordered_articles()[1:3], + ["", ""] + ) + + def test_slicing_with_steps_can_be_used(self): + self.assertQuerysetEqual( + self.get_ordered_articles()[::2], [ + "", + "", + "", + "" + ] + ) + + def test_slicing_without_step_is_lazy(self): + with self.assertNumQueries(0): + self.get_ordered_articles()[0:5] + + def test_slicing_with_tests_is_not_lazy(self): + with self.assertNumQueries(1): + self.get_ordered_articles()[0:5:3] + + def test_slicing_can_slice_again_after_slicing(self): + self.assertQuerysetEqual( + self.get_ordered_articles()[0:5][0:2], + ["", ""] + ) + self.assertQuerysetEqual(self.get_ordered_articles()[0:5][4:], [""]) + self.assertQuerysetEqual(self.get_ordered_articles()[0:5][5:], []) + + # Some more tests! + self.assertQuerysetEqual( + self.get_ordered_articles()[2:][0:2], + ["", ""] + ) + self.assertQuerysetEqual( + self.get_ordered_articles()[2:][:2], + ["", ""] + ) + self.assertQuerysetEqual(self.get_ordered_articles()[2:][2:3], [""]) + + # Using an offset without a limit is also possible. 
+ self.assertQuerysetEqual( + self.get_ordered_articles()[5:], + ["", ""] + ) + + def test_slicing_cannot_filter_queryset_once_sliced(self): + with self.assertRaisesMessage(AssertionError, "Cannot filter a query once a slice has been taken."): + Article.objects.all()[0:5].filter(id=1) + + def test_slicing_cannot_reorder_queryset_once_sliced(self): + with self.assertRaisesMessage(AssertionError, "Cannot reorder a query once a slice has been taken."): + Article.objects.all()[0:5].order_by('id') + + def test_slicing_cannot_combine_queries_once_sliced(self): + with self.assertRaisesMessage(AssertionError, "Cannot combine queries once a slice has been taken."): + Article.objects.all()[0:1] & Article.objects.all()[4:5] + + def test_slicing_negative_indexing_not_supported_for_single_element(self): + """hint: inverting your ordering might do what you need""" + with self.assertRaisesMessage(AssertionError, "Negative indexing is not supported."): + Article.objects.all()[-1] + + def test_slicing_negative_indexing_not_supported_for_range(self): + """hint: inverting your ordering might do what you need""" + with self.assertRaisesMessage(AssertionError, "Negative indexing is not supported."): + Article.objects.all()[0:-5] + + def test_can_get_number_of_items_in_queryset_using_standard_len(self): + self.assertEqual(len(Article.objects.filter(name__exact='Article 1')), 1) + + def test_can_combine_queries_using_and_and_or_operators(self): + s1 = Article.objects.filter(name__exact='Article 1') + s2 = Article.objects.filter(name__exact='Article 2') + self.assertQuerysetEqual( + (s1 | s2).order_by('name'), + ["", ""] + ) + self.assertQuerysetEqual(s1 & s2, []) + + +class WeirdQuerysetSlicingTests(TestCase): + @classmethod + def setUpTestData(cls): + Number.objects.create(num=1) + Number.objects.create(num=2) + + Article.objects.create(name='one', created=datetime.datetime.now()) + Article.objects.create(name='two', created=datetime.datetime.now()) + 
Article.objects.create(name='three', created=datetime.datetime.now()) + Article.objects.create(name='four', created=datetime.datetime.now()) + + food = Food.objects.create(name='spam') + Eaten.objects.create(meal='spam with eggs', food=food) + + def test_tickets_7698_10202(self): + # People like to slice with '0' as the high-water mark. + self.assertQuerysetEqual(Article.objects.all()[0:0], []) + self.assertQuerysetEqual(Article.objects.all()[0:0][:10], []) + self.assertEqual(Article.objects.all()[:0].count(), 0) + with self.assertRaisesMessage(TypeError, 'Cannot reverse a query once a slice has been taken.'): + Article.objects.all()[:0].latest('created') + + def test_empty_resultset_sql(self): + # ticket #12192 + self.assertNumQueries(0, lambda: list(Number.objects.all()[1:1])) + + def test_empty_sliced_subquery(self): + self.assertEqual(Eaten.objects.filter(food__in=Food.objects.all()[0:0]).count(), 0) + + def test_empty_sliced_subquery_exclude(self): + self.assertEqual(Eaten.objects.exclude(food__in=Food.objects.all()[0:0]).count(), 1) + + def test_zero_length_values_slicing(self): + n = 42 + with self.assertNumQueries(0): + self.assertQuerysetEqual(Article.objects.values()[n:n], []) + self.assertQuerysetEqual(Article.objects.values_list()[n:n], []) + + +class EscapingTests(TestCase): + def test_ticket_7302(self): + # Reserved names are appropriately escaped + ReservedName.objects.create(name='a', order=42) + ReservedName.objects.create(name='b', order=37) + self.assertQuerysetEqual( + ReservedName.objects.all().order_by('order'), + ['', ''] + ) + self.assertQuerysetEqual( + ReservedName.objects.extra(select={'stuff': 'name'}, order_by=('order', 'stuff')), + ['', ''] + ) + + +class ToFieldTests(TestCase): + def test_in_query(self): + apple = Food.objects.create(name="apple") + pear = Food.objects.create(name="pear") + lunch = Eaten.objects.create(food=apple, meal="lunch") + dinner = Eaten.objects.create(food=pear, meal="dinner") + + self.assertEqual( + 
set(Eaten.objects.filter(food__in=[apple, pear])), + {lunch, dinner}, + ) + + def test_in_subquery(self): + apple = Food.objects.create(name="apple") + lunch = Eaten.objects.create(food=apple, meal="lunch") + self.assertEqual( + set(Eaten.objects.filter(food__in=Food.objects.filter(name='apple'))), + {lunch} + ) + self.assertEqual( + set(Eaten.objects.filter(food__in=Food.objects.filter(name='apple').values('eaten__meal'))), + set() + ) + self.assertEqual( + set(Food.objects.filter(eaten__in=Eaten.objects.filter(meal='lunch'))), + {apple} + ) + + def test_nested_in_subquery(self): + extra = ExtraInfo.objects.create() + author = Author.objects.create(num=42, extra=extra) + report = Report.objects.create(creator=author) + comment = ReportComment.objects.create(report=report) + comments = ReportComment.objects.filter( + report__in=Report.objects.filter( + creator__in=extra.author_set.all(), + ), + ) + self.assertSequenceEqual(comments, [comment]) + + def test_reverse_in(self): + apple = Food.objects.create(name="apple") + pear = Food.objects.create(name="pear") + lunch_apple = Eaten.objects.create(food=apple, meal="lunch") + lunch_pear = Eaten.objects.create(food=pear, meal="dinner") + + self.assertEqual( + set(Food.objects.filter(eaten__in=[lunch_apple, lunch_pear])), + {apple, pear} + ) + + def test_single_object(self): + apple = Food.objects.create(name="apple") + lunch = Eaten.objects.create(food=apple, meal="lunch") + dinner = Eaten.objects.create(food=apple, meal="dinner") + + self.assertEqual( + set(Eaten.objects.filter(food=apple)), + {lunch, dinner} + ) + + def test_single_object_reverse(self): + apple = Food.objects.create(name="apple") + lunch = Eaten.objects.create(food=apple, meal="lunch") + + self.assertEqual( + set(Food.objects.filter(eaten=lunch)), + {apple} + ) + + def test_recursive_fk(self): + node1 = Node.objects.create(num=42) + node2 = Node.objects.create(num=1, parent=node1) + + self.assertEqual( + list(Node.objects.filter(parent=node1)), + 
[node2] + ) + + def test_recursive_fk_reverse(self): + node1 = Node.objects.create(num=42) + node2 = Node.objects.create(num=1, parent=node1) + + self.assertEqual( + list(Node.objects.filter(node=node2)), + [node1] + ) + + +class IsNullTests(TestCase): + def test_primary_key(self): + custom = CustomPk.objects.create(name='pk') + null = Related.objects.create() + notnull = Related.objects.create(custom=custom) + self.assertSequenceEqual(Related.objects.filter(custom__isnull=False), [notnull]) + self.assertSequenceEqual(Related.objects.filter(custom__isnull=True), [null]) + + def test_to_field(self): + apple = Food.objects.create(name="apple") + Eaten.objects.create(food=apple, meal="lunch") + Eaten.objects.create(meal="lunch") + self.assertQuerysetEqual( + Eaten.objects.filter(food__isnull=False), + [''] + ) + self.assertQuerysetEqual( + Eaten.objects.filter(food__isnull=True), + [''] + ) + + +class ConditionalTests(TestCase): + """Tests whose execution depend on different environment conditions like + Python version or DB backend features""" + + @classmethod + def setUpTestData(cls): + generic = NamedCategory.objects.create(name="Generic") + t1 = Tag.objects.create(name='t1', category=generic) + Tag.objects.create(name='t2', parent=t1, category=generic) + t3 = Tag.objects.create(name='t3', parent=t1) + Tag.objects.create(name='t4', parent=t3) + Tag.objects.create(name='t5', parent=t3) + + def test_infinite_loop(self): + # If you're not careful, it's possible to introduce infinite loops via + # default ordering on foreign keys in a cycle. We detect that. 
+ with self.assertRaisesMessage(FieldError, 'Infinite loop caused by ordering.'): + list(LoopX.objects.all()) # Force queryset evaluation with list() + with self.assertRaisesMessage(FieldError, 'Infinite loop caused by ordering.'): + list(LoopZ.objects.all()) # Force queryset evaluation with list() + + # Note that this doesn't cause an infinite loop, since the default + # ordering on the Tag model is empty (and thus defaults to using "id" + # for the related field). + self.assertEqual(len(Tag.objects.order_by('parent')), 5) + + # ... but you can still order in a non-recursive fashion among linked + # fields (the previous test failed because the default ordering was + # recursive). + self.assertQuerysetEqual( + LoopX.objects.all().order_by('y__x__y__x__id'), + [] + ) + + # When grouping without specifying ordering, we add an explicit "ORDER BY NULL" + # portion in MySQL to prevent unnecessary sorting. + @skipUnlessDBFeature('requires_explicit_null_ordering_when_grouping') + def test_null_ordering_added(self): + query = Tag.objects.values_list('parent_id', flat=True).order_by().query + query.group_by = ['parent_id'] + sql = query.get_compiler(DEFAULT_DB_ALIAS).as_sql()[0] + fragment = "ORDER BY " + pos = sql.find(fragment) + self.assertEqual(sql.find(fragment, pos + 1), -1) + self.assertEqual(sql.find("NULL", pos + len(fragment)), pos + len(fragment)) + + def test_in_list_limit(self): + # The "in" lookup works with lists of 1000 items or more. + # The numbers amount is picked to force three different IN batches + # for Oracle, yet to be less than 2100 parameter limit for MSSQL. 
+ numbers = list(range(2050)) + max_query_params = connection.features.max_query_params + if max_query_params is None or max_query_params >= len(numbers): + Number.objects.bulk_create(Number(num=num) for num in numbers) + for number in [1000, 1001, 2000, len(numbers)]: + with self.subTest(number=number): + self.assertEqual(Number.objects.filter(num__in=numbers[:number]).count(), number) + + +class UnionTests(unittest.TestCase): + """ + Tests for the union of two querysets. Bug #12252. + """ + @classmethod + def setUpTestData(cls): + objectas = [] + objectbs = [] + objectcs = [] + a_info = ['one', 'two', 'three'] + for name in a_info: + o = ObjectA(name=name) + o.save() + objectas.append(o) + b_info = [('un', 1, objectas[0]), ('deux', 2, objectas[0]), ('trois', 3, objectas[2])] + for name, number, objecta in b_info: + o = ObjectB(name=name, num=number, objecta=objecta) + o.save() + objectbs.append(o) + c_info = [('ein', objectas[2], objectbs[2]), ('zwei', objectas[1], objectbs[1])] + for name, objecta, objectb in c_info: + o = ObjectC(name=name, objecta=objecta, objectb=objectb) + o.save() + objectcs.append(o) + + def check_union(self, model, Q1, Q2): + filter = model.objects.filter + self.assertEqual(set(filter(Q1) | filter(Q2)), set(filter(Q1 | Q2))) + self.assertEqual(set(filter(Q2) | filter(Q1)), set(filter(Q1 | Q2))) + + def test_A_AB(self): + Q1 = Q(name='two') + Q2 = Q(objectb__name='deux') + self.check_union(ObjectA, Q1, Q2) + + def test_A_AB2(self): + Q1 = Q(name='two') + Q2 = Q(objectb__name='deux', objectb__num=2) + self.check_union(ObjectA, Q1, Q2) + + def test_AB_ACB(self): + Q1 = Q(objectb__name='deux') + Q2 = Q(objectc__objectb__name='deux') + self.check_union(ObjectA, Q1, Q2) + + def test_BAB_BAC(self): + Q1 = Q(objecta__objectb__name='deux') + Q2 = Q(objecta__objectc__name='ein') + self.check_union(ObjectB, Q1, Q2) + + def test_BAB_BACB(self): + Q1 = Q(objecta__objectb__name='deux') + Q2 = Q(objecta__objectc__objectb__name='trois') + 
self.check_union(ObjectB, Q1, Q2) + + def test_BA_BCA__BAB_BAC_BCA(self): + Q1 = Q(objecta__name='one', objectc__objecta__name='two') + Q2 = Q(objecta__objectc__name='ein', objectc__objecta__name='three', objecta__objectb__name='trois') + self.check_union(ObjectB, Q1, Q2) + + +class DefaultValuesInsertTest(TestCase): + def test_no_extra_params(self): + """ + Can create an instance of a model with only the PK field (#17056)." + """ + DumbCategory.objects.create() + + +class ExcludeTests(TestCase): + @classmethod + def setUpTestData(cls): + f1 = Food.objects.create(name='apples') + Food.objects.create(name='oranges') + Eaten.objects.create(food=f1, meal='dinner') + j1 = Job.objects.create(name='Manager') + r1 = Responsibility.objects.create(description='Playing golf') + j2 = Job.objects.create(name='Programmer') + r2 = Responsibility.objects.create(description='Programming') + JobResponsibilities.objects.create(job=j1, responsibility=r1) + JobResponsibilities.objects.create(job=j2, responsibility=r2) + + def test_to_field(self): + self.assertQuerysetEqual( + Food.objects.exclude(eaten__meal='dinner'), + ['']) + self.assertQuerysetEqual( + Job.objects.exclude(responsibilities__description='Playing golf'), + ['']) + self.assertQuerysetEqual( + Responsibility.objects.exclude(jobs__name='Manager'), + ['']) + + def test_ticket14511(self): + alex = Person.objects.get_or_create(name='Alex')[0] + jane = Person.objects.get_or_create(name='Jane')[0] + + oracle = Company.objects.get_or_create(name='Oracle')[0] + google = Company.objects.get_or_create(name='Google')[0] + microsoft = Company.objects.get_or_create(name='Microsoft')[0] + intel = Company.objects.get_or_create(name='Intel')[0] + + def employ(employer, employee, title): + Employment.objects.get_or_create(employee=employee, employer=employer, title=title) + + employ(oracle, alex, 'Engineer') + employ(oracle, alex, 'Developer') + employ(google, alex, 'Engineer') + employ(google, alex, 'Manager') + employ(microsoft, 
alex, 'Manager') + employ(intel, alex, 'Manager') + + employ(microsoft, jane, 'Developer') + employ(intel, jane, 'Manager') + + alex_tech_employers = alex.employers.filter( + employment__title__in=('Engineer', 'Developer')).distinct().order_by('name') + self.assertSequenceEqual(alex_tech_employers, [google, oracle]) + + alex_nontech_employers = alex.employers.exclude( + employment__title__in=('Engineer', 'Developer')).distinct().order_by('name') + self.assertSequenceEqual(alex_nontech_employers, [google, intel, microsoft]) + + +class ExcludeTest17600(TestCase): + """ + Some regressiontests for ticket #17600. Some of these likely duplicate + other existing tests. + """ + @classmethod + def setUpTestData(cls): + # Create a few Orders. + cls.o1 = Order.objects.create(pk=1) + cls.o2 = Order.objects.create(pk=2) + cls.o3 = Order.objects.create(pk=3) + + # Create some OrderItems for the first order with homogeneous + # status_id values + cls.oi1 = OrderItem.objects.create(order=cls.o1, status=1) + cls.oi2 = OrderItem.objects.create(order=cls.o1, status=1) + cls.oi3 = OrderItem.objects.create(order=cls.o1, status=1) + + # Create some OrderItems for the second order with heterogeneous + # status_id values + cls.oi4 = OrderItem.objects.create(order=cls.o2, status=1) + cls.oi5 = OrderItem.objects.create(order=cls.o2, status=2) + cls.oi6 = OrderItem.objects.create(order=cls.o2, status=3) + + # Create some OrderItems for the second order with heterogeneous + # status_id values + cls.oi7 = OrderItem.objects.create(order=cls.o3, status=2) + cls.oi8 = OrderItem.objects.create(order=cls.o3, status=3) + cls.oi9 = OrderItem.objects.create(order=cls.o3, status=4) + + def test_exclude_plain(self): + """ + This should exclude Orders which have some items with status 1 + """ + self.assertQuerysetEqual( + Order.objects.exclude(items__status=1), + ['']) + + def test_exclude_plain_distinct(self): + """ + This should exclude Orders which have some items with status 1 + """ + 
self.assertQuerysetEqual( + Order.objects.exclude(items__status=1).distinct(), + ['']) + + def test_exclude_with_q_object_distinct(self): + """ + This should exclude Orders which have some items with status 1 + """ + self.assertQuerysetEqual( + Order.objects.exclude(Q(items__status=1)).distinct(), + ['']) + + def test_exclude_with_q_object_no_distinct(self): + """ + This should exclude Orders which have some items with status 1 + """ + self.assertQuerysetEqual( + Order.objects.exclude(Q(items__status=1)), + ['']) + + def test_exclude_with_q_is_equal_to_plain_exclude(self): + """ + Using exclude(condition) and exclude(Q(condition)) should + yield the same QuerySet + """ + self.assertEqual( + list(Order.objects.exclude(items__status=1).distinct()), + list(Order.objects.exclude(Q(items__status=1)).distinct())) + + def test_exclude_with_q_is_equal_to_plain_exclude_variation(self): + """ + Using exclude(condition) and exclude(Q(condition)) should + yield the same QuerySet + """ + self.assertEqual( + list(Order.objects.exclude(items__status=1)), + list(Order.objects.exclude(Q(items__status=1)).distinct())) + + @unittest.expectedFailure + def test_only_orders_with_all_items_having_status_1(self): + """ + This should only return orders having ALL items set to status 1, or + those items not having any orders at all. The correct way to write + this query in SQL seems to be using two nested subqueries. 
+ """ + self.assertQuerysetEqual( + Order.objects.exclude(~Q(items__status=1)).distinct(), + ['']) + + +class Exclude15786(TestCase): + """Regression test for #15786""" + def test_ticket15786(self): + c1 = SimpleCategory.objects.create(name='c1') + c2 = SimpleCategory.objects.create(name='c2') + OneToOneCategory.objects.create(category=c1) + OneToOneCategory.objects.create(category=c2) + rel = CategoryRelationship.objects.create(first=c1, second=c2) + self.assertEqual( + CategoryRelationship.objects.exclude( + first__onetoonecategory=F('second__onetoonecategory') + ).get(), rel + ) + + +class NullInExcludeTest(TestCase): + @classmethod + def setUpTestData(cls): + NullableName.objects.create(name='i1') + NullableName.objects.create() + + def test_null_in_exclude_qs(self): + none_val = '' if connection.features.interprets_empty_strings_as_nulls else None + self.assertQuerysetEqual( + NullableName.objects.exclude(name__in=[]), + ['i1', none_val], attrgetter('name')) + self.assertQuerysetEqual( + NullableName.objects.exclude(name__in=['i1']), + [none_val], attrgetter('name')) + self.assertQuerysetEqual( + NullableName.objects.exclude(name__in=['i3']), + ['i1', none_val], attrgetter('name')) + inner_qs = NullableName.objects.filter(name='i1').values_list('name') + self.assertQuerysetEqual( + NullableName.objects.exclude(name__in=inner_qs), + [none_val], attrgetter('name')) + # The inner queryset wasn't executed - it should be turned + # into subquery above + self.assertIs(inner_qs._result_cache, None) + + @unittest.expectedFailure + def test_col_not_in_list_containing_null(self): + """ + The following case is not handled properly because + SQL's COL NOT IN (list containing null) handling is too weird to + abstract away. 
+ """ + self.assertQuerysetEqual( + NullableName.objects.exclude(name__in=[None]), + ['i1'], attrgetter('name')) + + def test_double_exclude(self): + self.assertEqual( + list(NullableName.objects.filter(~~Q(name='i1'))), + list(NullableName.objects.filter(Q(name='i1')))) + self.assertNotIn( + 'IS NOT NULL', + str(NullableName.objects.filter(~~Q(name='i1')).query)) + + +class EmptyStringsAsNullTest(TestCase): + """ + Filtering on non-null character fields works as expected. + The reason for these tests is that Oracle treats '' as NULL, and this + can cause problems in query construction. Refs #17957. + """ + @classmethod + def setUpTestData(cls): + cls.nc = NamedCategory.objects.create(name='') + + def test_direct_exclude(self): + self.assertQuerysetEqual( + NamedCategory.objects.exclude(name__in=['nonexistent']), + [self.nc.pk], attrgetter('pk') + ) + + def test_joined_exclude(self): + self.assertQuerysetEqual( + DumbCategory.objects.exclude(namedcategory__name__in=['nonexistent']), + [self.nc.pk], attrgetter('pk') + ) + + def test_21001(self): + foo = NamedCategory.objects.create(name='foo') + self.assertQuerysetEqual( + NamedCategory.objects.exclude(name=''), + [foo.pk], attrgetter('pk') + ) + + +class ProxyQueryCleanupTest(TestCase): + def test_evaluated_proxy_count(self): + """ + Generating the query string doesn't alter the query's state + in irreversible ways. Refs #18248. 
+ """ + ProxyCategory.objects.create() + qs = ProxyCategory.objects.all() + self.assertEqual(qs.count(), 1) + str(qs.query) + self.assertEqual(qs.count(), 1) + + +class WhereNodeTest(TestCase): + class DummyNode: + def as_sql(self, compiler, connection): + return 'dummy', [] + + class MockCompiler: + def compile(self, node): + return node.as_sql(self, connection) + + def __call__(self, name): + return connection.ops.quote_name(name) + + def test_empty_full_handling_conjunction(self): + compiler = WhereNodeTest.MockCompiler() + w = WhereNode(children=[NothingNode()]) + with self.assertRaises(EmptyResultSet): + w.as_sql(compiler, connection) + w.negate() + self.assertEqual(w.as_sql(compiler, connection), ('', [])) + w = WhereNode(children=[self.DummyNode(), self.DummyNode()]) + self.assertEqual(w.as_sql(compiler, connection), ('(dummy AND dummy)', [])) + w.negate() + self.assertEqual(w.as_sql(compiler, connection), ('NOT (dummy AND dummy)', [])) + w = WhereNode(children=[NothingNode(), self.DummyNode()]) + with self.assertRaises(EmptyResultSet): + w.as_sql(compiler, connection) + w.negate() + self.assertEqual(w.as_sql(compiler, connection), ('', [])) + + def test_empty_full_handling_disjunction(self): + compiler = WhereNodeTest.MockCompiler() + w = WhereNode(children=[NothingNode()], connector='OR') + with self.assertRaises(EmptyResultSet): + w.as_sql(compiler, connection) + w.negate() + self.assertEqual(w.as_sql(compiler, connection), ('', [])) + w = WhereNode(children=[self.DummyNode(), self.DummyNode()], connector='OR') + self.assertEqual(w.as_sql(compiler, connection), ('(dummy OR dummy)', [])) + w.negate() + self.assertEqual(w.as_sql(compiler, connection), ('NOT (dummy OR dummy)', [])) + w = WhereNode(children=[NothingNode(), self.DummyNode()], connector='OR') + self.assertEqual(w.as_sql(compiler, connection), ('dummy', [])) + w.negate() + self.assertEqual(w.as_sql(compiler, connection), ('NOT (dummy)', [])) + + def test_empty_nodes(self): + compiler = 
WhereNodeTest.MockCompiler() + empty_w = WhereNode() + w = WhereNode(children=[empty_w, empty_w]) + self.assertEqual(w.as_sql(compiler, connection), ('', [])) + w.negate() + with self.assertRaises(EmptyResultSet): + w.as_sql(compiler, connection) + w.connector = 'OR' + with self.assertRaises(EmptyResultSet): + w.as_sql(compiler, connection) + w.negate() + self.assertEqual(w.as_sql(compiler, connection), ('', [])) + w = WhereNode(children=[empty_w, NothingNode()], connector='OR') + self.assertEqual(w.as_sql(compiler, connection), ('', [])) + w = WhereNode(children=[empty_w, NothingNode()], connector='AND') + with self.assertRaises(EmptyResultSet): + w.as_sql(compiler, connection) + + +class QuerySetExceptionTests(TestCase): + def test_iter_exceptions(self): + qs = ExtraInfo.objects.only('author') + msg = "'ManyToOneRel' object has no attribute 'attname'" + with self.assertRaisesMessage(AttributeError, msg): + list(qs) + + def test_invalid_qs_list(self): + # Test for #19895 - second iteration over invalid queryset + # raises errors. + qs = Article.objects.order_by('invalid_column') + msg = "Cannot resolve keyword 'invalid_column' into field." + with self.assertRaisesMessage(FieldError, msg): + list(qs) + with self.assertRaisesMessage(FieldError, msg): + list(qs) + + def test_invalid_order_by(self): + msg = "Invalid order_by arguments: ['*']" + with self.assertRaisesMessage(FieldError, msg): + list(Article.objects.order_by('*')) + + def test_invalid_queryset_model(self): + msg = 'Cannot use QuerySet for "Article": Use a QuerySet for "ExtraInfo".' 
+ with self.assertRaisesMessage(ValueError, msg): + list(Author.objects.filter(extra=Article.objects.all())) + + +class NullJoinPromotionOrTest(TestCase): + @classmethod + def setUpTestData(cls): + cls.d1 = ModelD.objects.create(name='foo') + d2 = ModelD.objects.create(name='bar') + cls.a1 = ModelA.objects.create(name='a1', d=cls.d1) + c = ModelC.objects.create(name='c') + b = ModelB.objects.create(name='b', c=c) + cls.a2 = ModelA.objects.create(name='a2', b=b, d=d2) + + def test_ticket_17886(self): + # The first Q-object is generating the match, the rest of the filters + # should not remove the match even if they do not match anything. The + # problem here was that b__name generates a LOUTER JOIN, then + # b__c__name generates join to c, which the ORM tried to promote but + # failed as that join isn't nullable. + q_obj = ( + Q(d__name='foo') | + Q(b__name='foo') | + Q(b__c__name='foo') + ) + qset = ModelA.objects.filter(q_obj) + self.assertEqual(list(qset), [self.a1]) + # We generate one INNER JOIN to D. The join is direct and not nullable + # so we can use INNER JOIN for it. However, we can NOT use INNER JOIN + # for the b->c join, as a->b is nullable. 
+ self.assertEqual(str(qset.query).count('INNER JOIN'), 1) + + def test_isnull_filter_promotion(self): + qs = ModelA.objects.filter(Q(b__name__isnull=True)) + self.assertEqual(str(qs.query).count('LEFT OUTER'), 1) + self.assertEqual(list(qs), [self.a1]) + + qs = ModelA.objects.filter(~Q(b__name__isnull=True)) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(list(qs), [self.a2]) + + qs = ModelA.objects.filter(~~Q(b__name__isnull=True)) + self.assertEqual(str(qs.query).count('LEFT OUTER'), 1) + self.assertEqual(list(qs), [self.a1]) + + qs = ModelA.objects.filter(Q(b__name__isnull=False)) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(list(qs), [self.a2]) + + qs = ModelA.objects.filter(~Q(b__name__isnull=False)) + self.assertEqual(str(qs.query).count('LEFT OUTER'), 1) + self.assertEqual(list(qs), [self.a1]) + + qs = ModelA.objects.filter(~~Q(b__name__isnull=False)) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(list(qs), [self.a2]) + + def test_null_join_demotion(self): + qs = ModelA.objects.filter(Q(b__name__isnull=False) & Q(b__name__isnull=True)) + self.assertIn(' INNER JOIN ', str(qs.query)) + qs = ModelA.objects.filter(Q(b__name__isnull=True) & Q(b__name__isnull=False)) + self.assertIn(' INNER JOIN ', str(qs.query)) + qs = ModelA.objects.filter(Q(b__name__isnull=False) | Q(b__name__isnull=True)) + self.assertIn(' LEFT OUTER JOIN ', str(qs.query)) + qs = ModelA.objects.filter(Q(b__name__isnull=True) | Q(b__name__isnull=False)) + self.assertIn(' LEFT OUTER JOIN ', str(qs.query)) + + def test_ticket_21366(self): + n = Note.objects.create(note='n', misc='m') + e = ExtraInfo.objects.create(info='info', note=n) + a = Author.objects.create(name='Author1', num=1, extra=e) + Ranking.objects.create(rank=1, author=a) + r1 = Report.objects.create(name='Foo', creator=a) + r2 = Report.objects.create(name='Bar') + Report.objects.create(name='Bar', creator=a) + qs = Report.objects.filter( + 
Q(creator__ranking__isnull=True) | + Q(creator__ranking__rank=1, name='Foo') + ) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 2) + self.assertEqual(str(qs.query).count(' JOIN '), 2) + self.assertSequenceEqual(qs.order_by('name'), [r2, r1]) + + def test_ticket_21748(self): + i1 = Identifier.objects.create(name='i1') + i2 = Identifier.objects.create(name='i2') + i3 = Identifier.objects.create(name='i3') + Program.objects.create(identifier=i1) + Channel.objects.create(identifier=i1) + Program.objects.create(identifier=i2) + self.assertSequenceEqual(Identifier.objects.filter(program=None, channel=None), [i3]) + self.assertSequenceEqual(Identifier.objects.exclude(program=None, channel=None).order_by('name'), [i1, i2]) + + def test_ticket_21748_double_negated_and(self): + i1 = Identifier.objects.create(name='i1') + i2 = Identifier.objects.create(name='i2') + Identifier.objects.create(name='i3') + p1 = Program.objects.create(identifier=i1) + c1 = Channel.objects.create(identifier=i1) + Program.objects.create(identifier=i2) + # Check the ~~Q() (or equivalently .exclude(~Q)) works like Q() for + # join promotion. 
+ qs1_doubleneg = Identifier.objects.exclude(~Q(program__id=p1.id, channel__id=c1.id)).order_by('pk') + qs1_filter = Identifier.objects.filter(program__id=p1.id, channel__id=c1.id).order_by('pk') + self.assertQuerysetEqual(qs1_doubleneg, qs1_filter, lambda x: x) + self.assertEqual(str(qs1_filter.query).count('JOIN'), + str(qs1_doubleneg.query).count('JOIN')) + self.assertEqual(2, str(qs1_doubleneg.query).count('INNER JOIN')) + self.assertEqual(str(qs1_filter.query).count('INNER JOIN'), + str(qs1_doubleneg.query).count('INNER JOIN')) + + def test_ticket_21748_double_negated_or(self): + i1 = Identifier.objects.create(name='i1') + i2 = Identifier.objects.create(name='i2') + Identifier.objects.create(name='i3') + p1 = Program.objects.create(identifier=i1) + c1 = Channel.objects.create(identifier=i1) + p2 = Program.objects.create(identifier=i2) + # Test OR + doubleneg. The expected result is that channel is LOUTER + # joined, program INNER joined + qs1_filter = Identifier.objects.filter( + Q(program__id=p2.id, channel__id=c1.id) | Q(program__id=p1.id) + ).order_by('pk') + qs1_doubleneg = Identifier.objects.exclude( + ~Q(Q(program__id=p2.id, channel__id=c1.id) | Q(program__id=p1.id)) + ).order_by('pk') + self.assertQuerysetEqual(qs1_doubleneg, qs1_filter, lambda x: x) + self.assertEqual(str(qs1_filter.query).count('JOIN'), + str(qs1_doubleneg.query).count('JOIN')) + self.assertEqual(1, str(qs1_doubleneg.query).count('INNER JOIN')) + self.assertEqual(str(qs1_filter.query).count('INNER JOIN'), + str(qs1_doubleneg.query).count('INNER JOIN')) + + def test_ticket_21748_complex_filter(self): + i1 = Identifier.objects.create(name='i1') + i2 = Identifier.objects.create(name='i2') + Identifier.objects.create(name='i3') + p1 = Program.objects.create(identifier=i1) + c1 = Channel.objects.create(identifier=i1) + p2 = Program.objects.create(identifier=i2) + # Finally, a more complex case, one time in a way where each + # NOT is pushed to lowest level in the boolean tree, and + # 
another query where this isn't done. + qs1 = Identifier.objects.filter( + ~Q(~Q(program__id=p2.id, channel__id=c1.id) & Q(program__id=p1.id)) + ).order_by('pk') + qs2 = Identifier.objects.filter( + Q(Q(program__id=p2.id, channel__id=c1.id) | ~Q(program__id=p1.id)) + ).order_by('pk') + self.assertQuerysetEqual(qs1, qs2, lambda x: x) + self.assertEqual(str(qs1.query).count('JOIN'), + str(qs2.query).count('JOIN')) + self.assertEqual(0, str(qs1.query).count('INNER JOIN')) + self.assertEqual(str(qs1.query).count('INNER JOIN'), + str(qs2.query).count('INNER JOIN')) + + +class ReverseJoinTrimmingTest(TestCase): + def test_reverse_trimming(self): + # We don't accidentally trim reverse joins - we can't know if there is + # anything on the other side of the join, so trimming reverse joins + # can't be done, ever. + t = Tag.objects.create() + qs = Tag.objects.filter(annotation__tag=t.pk) + self.assertIn('INNER JOIN', str(qs.query)) + self.assertEqual(list(qs), []) + + +class JoinReuseTest(TestCase): + """ + The queries reuse joins sensibly (for example, direct joins + are always reused). 
+ """ + def test_fk_reuse(self): + qs = Annotation.objects.filter(tag__name='foo').filter(tag__name='bar') + self.assertEqual(str(qs.query).count('JOIN'), 1) + + def test_fk_reuse_select_related(self): + qs = Annotation.objects.filter(tag__name='foo').select_related('tag') + self.assertEqual(str(qs.query).count('JOIN'), 1) + + def test_fk_reuse_annotation(self): + qs = Annotation.objects.filter(tag__name='foo').annotate(cnt=Count('tag__name')) + self.assertEqual(str(qs.query).count('JOIN'), 1) + + def test_fk_reuse_disjunction(self): + qs = Annotation.objects.filter(Q(tag__name='foo') | Q(tag__name='bar')) + self.assertEqual(str(qs.query).count('JOIN'), 1) + + def test_fk_reuse_order_by(self): + qs = Annotation.objects.filter(tag__name='foo').order_by('tag__name') + self.assertEqual(str(qs.query).count('JOIN'), 1) + + def test_revo2o_reuse(self): + qs = Detail.objects.filter(member__name='foo').filter(member__name='foo') + self.assertEqual(str(qs.query).count('JOIN'), 1) + + def test_revfk_noreuse(self): + qs = Author.objects.filter(report__name='r4').filter(report__name='r1') + self.assertEqual(str(qs.query).count('JOIN'), 2) + + def test_inverted_q_across_relations(self): + """ + When a trimmable join is specified in the query (here school__), the + ORM detects it and removes unnecessary joins. The set of reusable joins + are updated after trimming the query so that other lookups don't + consider that the outer query's filters are in effect for the subquery + (#26551). 
+ """ + springfield_elementary = School.objects.create() + hogward = School.objects.create() + Student.objects.create(school=springfield_elementary) + hp = Student.objects.create(school=hogward) + Classroom.objects.create(school=hogward, name='Potion') + Classroom.objects.create(school=springfield_elementary, name='Main') + qs = Student.objects.filter( + ~(Q(school__classroom__name='Main') & Q(school__classroom__has_blackboard=None)) + ) + self.assertSequenceEqual(qs, [hp]) + + +class DisjunctionPromotionTests(TestCase): + def test_disjunction_promotion_select_related(self): + fk1 = FK1.objects.create(f1='f1', f2='f2') + basea = BaseA.objects.create(a=fk1) + qs = BaseA.objects.filter(Q(a=fk1) | Q(b=2)) + self.assertEqual(str(qs.query).count(' JOIN '), 0) + qs = qs.select_related('a', 'b') + self.assertEqual(str(qs.query).count(' INNER JOIN '), 0) + self.assertEqual(str(qs.query).count(' LEFT OUTER JOIN '), 2) + with self.assertNumQueries(1): + self.assertSequenceEqual(qs, [basea]) + self.assertEqual(qs[0].a, fk1) + self.assertIs(qs[0].b, None) + + def test_disjunction_promotion1(self): + # Pre-existing join, add two ORed filters to the same join, + # all joins can be INNER JOINS. + qs = BaseA.objects.filter(a__f1='foo') + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + qs = qs.filter(Q(b__f1='foo') | Q(b__f2='foo')) + self.assertEqual(str(qs.query).count('INNER JOIN'), 2) + # Reverse the order of AND and OR filters. + qs = BaseA.objects.filter(Q(b__f1='foo') | Q(b__f2='foo')) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + qs = qs.filter(a__f1='foo') + self.assertEqual(str(qs.query).count('INNER JOIN'), 2) + + def test_disjunction_promotion2(self): + qs = BaseA.objects.filter(a__f1='foo') + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + # Now we have two different joins in an ORed condition, these + # must be OUTER joins. The pre-existing join should remain INNER. 
+ qs = qs.filter(Q(b__f1='foo') | Q(c__f2='foo')) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 2) + # Reverse case. + qs = BaseA.objects.filter(Q(b__f1='foo') | Q(c__f2='foo')) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 2) + qs = qs.filter(a__f1='foo') + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 2) + + def test_disjunction_promotion3(self): + qs = BaseA.objects.filter(a__f2='bar') + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + # The ANDed a__f2 filter allows us to use keep using INNER JOIN + # even inside the ORed case. If the join to a__ returns nothing, + # the ANDed filter for a__f2 can't be true. + qs = qs.filter(Q(a__f1='foo') | Q(b__f2='foo')) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 1) + + def test_disjunction_promotion3_demote(self): + # This one needs demotion logic: the first filter causes a to be + # outer joined, the second filter makes it inner join again. + qs = BaseA.objects.filter( + Q(a__f1='foo') | Q(b__f2='foo')).filter(a__f2='bar') + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 1) + + def test_disjunction_promotion4_demote(self): + qs = BaseA.objects.filter(Q(a=1) | Q(a=2)) + self.assertEqual(str(qs.query).count('JOIN'), 0) + # Demote needed for the "a" join. It is marked as outer join by + # above filter (even if it is trimmed away). 
+ qs = qs.filter(a__f1='foo') + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + + def test_disjunction_promotion4(self): + qs = BaseA.objects.filter(a__f1='foo') + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + qs = qs.filter(Q(a=1) | Q(a=2)) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + + def test_disjunction_promotion5_demote(self): + qs = BaseA.objects.filter(Q(a=1) | Q(a=2)) + # Note that the above filters on a force the join to an + # inner join even if it is trimmed. + self.assertEqual(str(qs.query).count('JOIN'), 0) + qs = qs.filter(Q(a__f1='foo') | Q(b__f1='foo')) + # So, now the a__f1 join doesn't need promotion. + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + # But b__f1 does. + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 1) + qs = BaseA.objects.filter(Q(a__f1='foo') | Q(b__f1='foo')) + # Now the join to a is created as LOUTER + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 2) + qs = qs.filter(Q(a=1) | Q(a=2)) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 1) + + def test_disjunction_promotion6(self): + qs = BaseA.objects.filter(Q(a=1) | Q(a=2)) + self.assertEqual(str(qs.query).count('JOIN'), 0) + qs = BaseA.objects.filter(Q(a__f1='foo') & Q(b__f1='foo')) + self.assertEqual(str(qs.query).count('INNER JOIN'), 2) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 0) + + qs = BaseA.objects.filter(Q(a__f1='foo') & Q(b__f1='foo')) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 0) + self.assertEqual(str(qs.query).count('INNER JOIN'), 2) + qs = qs.filter(Q(a=1) | Q(a=2)) + self.assertEqual(str(qs.query).count('INNER JOIN'), 2) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 0) + + def test_disjunction_promotion7(self): + qs = BaseA.objects.filter(Q(a=1) | Q(a=2)) + self.assertEqual(str(qs.query).count('JOIN'), 0) + qs = BaseA.objects.filter(Q(a__f1='foo') | (Q(b__f1='foo') & Q(a__f1='bar'))) + 
self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 1) + qs = BaseA.objects.filter( + (Q(a__f1='foo') | Q(b__f1='foo')) & (Q(a__f1='bar') | Q(c__f1='foo')) + ) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 3) + self.assertEqual(str(qs.query).count('INNER JOIN'), 0) + qs = BaseA.objects.filter( + (Q(a__f1='foo') | (Q(a__f1='bar')) & (Q(b__f1='bar') | Q(c__f1='foo'))) + ) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 2) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + + def test_disjunction_promotion_fexpression(self): + qs = BaseA.objects.filter(Q(a__f1=F('b__f1')) | Q(b__f1='foo')) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 1) + self.assertEqual(str(qs.query).count('INNER JOIN'), 1) + qs = BaseA.objects.filter(Q(a__f1=F('c__f1')) | Q(b__f1='foo')) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 3) + qs = BaseA.objects.filter(Q(a__f1=F('b__f1')) | Q(a__f2=F('b__f2')) | Q(c__f1='foo')) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 3) + qs = BaseA.objects.filter(Q(a__f1=F('c__f1')) | (Q(pk=1) & Q(pk=2))) + self.assertEqual(str(qs.query).count('LEFT OUTER JOIN'), 2) + self.assertEqual(str(qs.query).count('INNER JOIN'), 0) + + +class ManyToManyExcludeTest(TestCase): + def test_exclude_many_to_many(self): + Identifier.objects.create(name='extra') + program = Program.objects.create(identifier=Identifier.objects.create(name='program')) + channel = Channel.objects.create(identifier=Identifier.objects.create(name='channel')) + channel.programs.add(program) + + # channel contains 'program1', so all Identifiers except that one + # should be returned + self.assertQuerysetEqual( + Identifier.objects.exclude(program__channel=channel).order_by('name'), + ['', ''] + ) + self.assertQuerysetEqual( + Identifier.objects.exclude(program__channel=None).order_by('name'), + [''] + ) + + def test_ticket_12823(self): + pg3 = Page.objects.create(text='pg3') 
+ pg2 = Page.objects.create(text='pg2') + pg1 = Page.objects.create(text='pg1') + pa1 = Paragraph.objects.create(text='pa1') + pa1.page.set([pg1, pg2]) + pa2 = Paragraph.objects.create(text='pa2') + pa2.page.set([pg2, pg3]) + pa3 = Paragraph.objects.create(text='pa3') + ch1 = Chapter.objects.create(title='ch1', paragraph=pa1) + ch2 = Chapter.objects.create(title='ch2', paragraph=pa2) + ch3 = Chapter.objects.create(title='ch3', paragraph=pa3) + b1 = Book.objects.create(title='b1', chapter=ch1) + b2 = Book.objects.create(title='b2', chapter=ch2) + b3 = Book.objects.create(title='b3', chapter=ch3) + q = Book.objects.exclude(chapter__paragraph__page__text='pg1') + self.assertNotIn('IS NOT NULL', str(q.query)) + self.assertEqual(len(q), 2) + self.assertNotIn(b1, q) + self.assertIn(b2, q) + self.assertIn(b3, q) + + +class RelabelCloneTest(TestCase): + def test_ticket_19964(self): + my1 = MyObject.objects.create(data='foo') + my1.parent = my1 + my1.save() + my2 = MyObject.objects.create(data='bar', parent=my1) + parents = MyObject.objects.filter(parent=F('id')) + children = MyObject.objects.filter(parent__in=parents).exclude(parent=F('id')) + self.assertEqual(list(parents), [my1]) + # Evaluating the children query (which has parents as part of it) does + # not change results for the parents query. + self.assertEqual(list(children), [my2]) + self.assertEqual(list(parents), [my1]) + + +class Ticket20101Tests(TestCase): + def test_ticket_20101(self): + """ + Tests QuerySet ORed combining in exclude subquery case. 
+ """ + t = Tag.objects.create(name='foo') + a1 = Annotation.objects.create(tag=t, name='a1') + a2 = Annotation.objects.create(tag=t, name='a2') + a3 = Annotation.objects.create(tag=t, name='a3') + n = Note.objects.create(note='foo', misc='bar') + qs1 = Note.objects.exclude(annotation__in=[a1, a2]) + qs2 = Note.objects.filter(annotation__in=[a3]) + self.assertIn(n, qs1) + self.assertNotIn(n, qs2) + self.assertIn(n, (qs1 | qs2)) + + +class EmptyStringPromotionTests(TestCase): + def test_empty_string_promotion(self): + qs = RelatedObject.objects.filter(single__name='') + if connection.features.interprets_empty_strings_as_nulls: + self.assertIn('LEFT OUTER JOIN', str(qs.query)) + else: + self.assertNotIn('LEFT OUTER JOIN', str(qs.query)) + + +class ValuesSubqueryTests(TestCase): + def test_values_in_subquery(self): + # If a values() queryset is used, then the given values + # will be used instead of forcing use of the relation's field. + o1 = Order.objects.create(id=-2) + o2 = Order.objects.create(id=-1) + oi1 = OrderItem.objects.create(order=o1, status=0) + oi1.status = oi1.pk + oi1.save() + OrderItem.objects.create(order=o2, status=0) + + # The query below should match o1 as it has related order_item + # with id == status. 
+ self.assertSequenceEqual(Order.objects.filter(items__in=OrderItem.objects.values_list('status')), [o1]) + + +class DoubleInSubqueryTests(TestCase): + def test_double_subquery_in(self): + lfa1 = LeafA.objects.create(data='foo') + lfa2 = LeafA.objects.create(data='bar') + lfb1 = LeafB.objects.create(data='lfb1') + lfb2 = LeafB.objects.create(data='lfb2') + Join.objects.create(a=lfa1, b=lfb1) + Join.objects.create(a=lfa2, b=lfb2) + leaf_as = LeafA.objects.filter(data='foo').values_list('pk', flat=True) + joins = Join.objects.filter(a__in=leaf_as).values_list('b__id', flat=True) + qs = LeafB.objects.filter(pk__in=joins) + self.assertSequenceEqual(qs, [lfb1]) + + +class Ticket18785Tests(TestCase): + def test_ticket_18785(self): + # Test join trimming from ticket18785 + qs = Item.objects.exclude( + note__isnull=False + ).filter( + name='something', creator__extra__isnull=True + ).order_by() + self.assertEqual(1, str(qs.query).count('INNER JOIN')) + self.assertEqual(0, str(qs.query).count('OUTER JOIN')) + + +class Ticket20788Tests(TestCase): + def test_ticket_20788(self): + Paragraph.objects.create() + paragraph = Paragraph.objects.create() + page = paragraph.page.create() + chapter = Chapter.objects.create(paragraph=paragraph) + Book.objects.create(chapter=chapter) + + paragraph2 = Paragraph.objects.create() + Page.objects.create() + chapter2 = Chapter.objects.create(paragraph=paragraph2) + book2 = Book.objects.create(chapter=chapter2) + + sentences_not_in_pub = Book.objects.exclude(chapter__paragraph__page=page) + self.assertSequenceEqual(sentences_not_in_pub, [book2]) + + +class Ticket12807Tests(TestCase): + def test_ticket_12807(self): + p1 = Paragraph.objects.create() + p2 = Paragraph.objects.create() + # The ORed condition below should have no effect on the query - the + # ~Q(pk__in=[]) will always be True. 
+ qs = Paragraph.objects.filter((Q(pk=p2.pk) | ~Q(pk__in=[])) & Q(pk=p1.pk)) + self.assertSequenceEqual(qs, [p1]) + + +class RelatedLookupTypeTests(TestCase): + error = 'Cannot query "%s": Must be "%s" instance.' + + @classmethod + def setUpTestData(cls): + cls.oa = ObjectA.objects.create(name="oa") + cls.poa = ProxyObjectA.objects.get(name="oa") + cls.coa = ChildObjectA.objects.create(name="coa") + cls.wrong_type = Order.objects.create(id=cls.oa.pk) + cls.ob = ObjectB.objects.create(name="ob", objecta=cls.oa, num=1) + ProxyObjectB.objects.create(name="pob", objecta=cls.oa, num=2) + cls.pob = ProxyObjectB.objects.all() + ObjectC.objects.create(childobjecta=cls.coa) + + def test_wrong_type_lookup(self): + """ + A ValueError is raised when the incorrect object type is passed to a + query lookup. + """ + # Passing incorrect object type + with self.assertRaisesMessage(ValueError, self.error % (self.wrong_type, ObjectA._meta.object_name)): + ObjectB.objects.get(objecta=self.wrong_type) + + with self.assertRaisesMessage(ValueError, self.error % (self.wrong_type, ObjectA._meta.object_name)): + ObjectB.objects.filter(objecta__in=[self.wrong_type]) + + with self.assertRaisesMessage(ValueError, self.error % (self.wrong_type, ObjectA._meta.object_name)): + ObjectB.objects.filter(objecta=self.wrong_type) + + with self.assertRaisesMessage(ValueError, self.error % (self.wrong_type, ObjectB._meta.object_name)): + ObjectA.objects.filter(objectb__in=[self.wrong_type, self.ob]) + + # Passing an object of the class on which query is done. 
+ with self.assertRaisesMessage(ValueError, self.error % (self.ob, ObjectA._meta.object_name)): + ObjectB.objects.filter(objecta__in=[self.poa, self.ob]) + + with self.assertRaisesMessage(ValueError, self.error % (self.ob, ChildObjectA._meta.object_name)): + ObjectC.objects.exclude(childobjecta__in=[self.coa, self.ob]) + + def test_wrong_backward_lookup(self): + """ + A ValueError is raised when the incorrect object type is passed to a + query lookup for backward relations. + """ + with self.assertRaisesMessage(ValueError, self.error % (self.oa, ObjectB._meta.object_name)): + ObjectA.objects.filter(objectb__in=[self.oa, self.ob]) + + with self.assertRaisesMessage(ValueError, self.error % (self.oa, ObjectB._meta.object_name)): + ObjectA.objects.exclude(objectb=self.oa) + + with self.assertRaisesMessage(ValueError, self.error % (self.wrong_type, ObjectB._meta.object_name)): + ObjectA.objects.get(objectb=self.wrong_type) + + def test_correct_lookup(self): + """ + When passing proxy model objects, child objects, or parent objects, + lookups work fine. + """ + out_a = [''] + out_b = ['', ''] + out_c = [''] + + # proxy model objects + self.assertQuerysetEqual(ObjectB.objects.filter(objecta=self.poa).order_by('name'), out_b) + self.assertQuerysetEqual(ObjectA.objects.filter(objectb__in=self.pob).order_by('pk'), out_a * 2) + + # child objects + self.assertQuerysetEqual(ObjectB.objects.filter(objecta__in=[self.coa]), []) + self.assertQuerysetEqual(ObjectB.objects.filter(objecta__in=[self.poa, self.coa]).order_by('name'), out_b) + self.assertQuerysetEqual( + ObjectB.objects.filter(objecta__in=iter([self.poa, self.coa])).order_by('name'), + out_b + ) + + # parent objects + self.assertQuerysetEqual(ObjectC.objects.exclude(childobjecta=self.oa), out_c) + + # QuerySet related object type checking shouldn't issue queries + # (the querysets aren't evaluated here, hence zero queries) (#23266). 
+ with self.assertNumQueries(0): + ObjectB.objects.filter(objecta__in=ObjectA.objects.all()) + + def test_values_queryset_lookup(self): + """ + #23396 - Ensure ValueQuerySets are not checked for compatibility with the lookup field + """ + # Make sure the num and objecta field values match. + ob = ObjectB.objects.get(name='ob') + ob.num = ob.objecta.pk + ob.save() + pob = ObjectB.objects.get(name='pob') + pob.num = pob.objecta.pk + pob.save() + self.assertQuerysetEqual(ObjectB.objects.filter( + objecta__in=ObjectB.objects.all().values_list('num') + ).order_by('pk'), ['', '']) + + +class Ticket14056Tests(TestCase): + def test_ticket_14056(self): + s1 = SharedConnection.objects.create(data='s1') + s2 = SharedConnection.objects.create(data='s2') + s3 = SharedConnection.objects.create(data='s3') + PointerA.objects.create(connection=s2) + expected_ordering = ( + [s1, s3, s2] if connection.features.nulls_order_largest + else [s2, s1, s3] + ) + self.assertSequenceEqual(SharedConnection.objects.order_by('-pointera__connection', 'pk'), expected_ordering) + + +class Ticket20955Tests(TestCase): + def test_ticket_20955(self): + jack = Staff.objects.create(name='jackstaff') + jackstaff = StaffUser.objects.create(staff=jack) + jill = Staff.objects.create(name='jillstaff') + jillstaff = StaffUser.objects.create(staff=jill) + task = Task.objects.create(creator=jackstaff, owner=jillstaff, title="task") + task_get = Task.objects.get(pk=task.pk) + # Load data so that assertNumQueries doesn't complain about the get + # version's queries. 
+ task_get.creator.staffuser.staff + task_get.owner.staffuser.staff + qs = Task.objects.select_related( + 'creator__staffuser__staff', 'owner__staffuser__staff') + self.assertEqual(str(qs.query).count(' JOIN '), 6) + task_select_related = qs.get(pk=task.pk) + with self.assertNumQueries(0): + self.assertEqual(task_select_related.creator.staffuser.staff, + task_get.creator.staffuser.staff) + self.assertEqual(task_select_related.owner.staffuser.staff, + task_get.owner.staffuser.staff) + + +class Ticket21203Tests(TestCase): + def test_ticket_21203(self): + p = Ticket21203Parent.objects.create(parent_bool=True) + c = Ticket21203Child.objects.create(parent=p) + qs = Ticket21203Child.objects.select_related('parent').defer('parent__created') + self.assertSequenceEqual(qs, [c]) + self.assertIs(qs[0].parent.parent_bool, True) + + +class ValuesJoinPromotionTests(TestCase): + def test_values_no_promotion_for_existing(self): + qs = Node.objects.filter(parent__parent__isnull=False) + self.assertIn(' INNER JOIN ', str(qs.query)) + qs = qs.values('parent__parent__id') + self.assertIn(' INNER JOIN ', str(qs.query)) + # Make sure there is a left outer join without the filter. 
+ qs = Node.objects.values('parent__parent__id') + self.assertIn(' LEFT OUTER JOIN ', str(qs.query)) + + def test_non_nullable_fk_not_promoted(self): + qs = ObjectB.objects.values('objecta__name') + self.assertIn(' INNER JOIN ', str(qs.query)) + + def test_ticket_21376(self): + a = ObjectA.objects.create() + ObjectC.objects.create(objecta=a) + qs = ObjectC.objects.filter( + Q(objecta=a) | Q(objectb__objecta=a), + ) + qs = qs.filter( + Q(objectb=1) | Q(objecta=a), + ) + self.assertEqual(qs.count(), 1) + tblname = connection.ops.quote_name(ObjectB._meta.db_table) + self.assertIn(' LEFT OUTER JOIN %s' % tblname, str(qs.query)) + + +class ForeignKeyToBaseExcludeTests(TestCase): + def test_ticket_21787(self): + sc1 = SpecialCategory.objects.create(special_name='sc1', name='sc1') + sc2 = SpecialCategory.objects.create(special_name='sc2', name='sc2') + sc3 = SpecialCategory.objects.create(special_name='sc3', name='sc3') + c1 = CategoryItem.objects.create(category=sc1) + CategoryItem.objects.create(category=sc2) + self.assertSequenceEqual(SpecialCategory.objects.exclude(categoryitem__id=c1.pk).order_by('name'), [sc2, sc3]) + self.assertSequenceEqual(SpecialCategory.objects.filter(categoryitem__id=c1.pk), [sc1]) + + +class ReverseM2MCustomPkTests(TestCase): + def test_ticket_21879(self): + cpt1 = CustomPkTag.objects.create(id='cpt1', tag='cpt1') + cp1 = CustomPk.objects.create(name='cp1', extra='extra') + cp1.custompktag_set.add(cpt1) + self.assertSequenceEqual(CustomPk.objects.filter(custompktag=cpt1), [cp1]) + self.assertSequenceEqual(CustomPkTag.objects.filter(custom_pk=cp1), [cpt1]) + + +class Ticket22429Tests(TestCase): + def test_ticket_22429(self): + sc1 = School.objects.create() + st1 = Student.objects.create(school=sc1) + + sc2 = School.objects.create() + st2 = Student.objects.create(school=sc2) + + cr = Classroom.objects.create(school=sc1) + cr.students.add(st1) + + queryset = Student.objects.filter(~Q(classroom__school=F('school'))) + 
self.assertSequenceEqual(queryset, [st2]) + + +class Ticket23605Tests(TestCase): + def test_ticket_23605(self): + # Test filtering on a complicated q-object from ticket's report. + # The query structure is such that we have multiple nested subqueries. + # The original problem was that the inner queries weren't relabeled + # correctly. + # See also #24090. + a1 = Ticket23605A.objects.create() + a2 = Ticket23605A.objects.create() + c1 = Ticket23605C.objects.create(field_c0=10000.0) + Ticket23605B.objects.create( + field_b0=10000.0, field_b1=True, + modelc_fk=c1, modela_fk=a1) + complex_q = Q(pk__in=Ticket23605A.objects.filter( + Q( + # True for a1 as field_b0 = 10000, field_c0=10000 + # False for a2 as no ticket23605b found + ticket23605b__field_b0__gte=1000000 / + F("ticket23605b__modelc_fk__field_c0") + ) & + # True for a1 (field_b1=True) + Q(ticket23605b__field_b1=True) & ~Q(ticket23605b__pk__in=Ticket23605B.objects.filter( + ~( + # Same filters as above commented filters, but + # double-negated (one for Q() above, one for + # parentheses). So, again a1 match, a2 not. + Q(field_b1=True) & + Q(field_b0__gte=1000000 / F("modelc_fk__field_c0")) + ) + ))).filter(ticket23605b__field_b1=True)) + qs1 = Ticket23605A.objects.filter(complex_q) + self.assertSequenceEqual(qs1, [a1]) + qs2 = Ticket23605A.objects.exclude(complex_q) + self.assertSequenceEqual(qs2, [a2]) + + +class TestTicket24279(TestCase): + def test_ticket_24278(self): + School.objects.create() + qs = School.objects.filter(Q(pk__in=()) | Q()) + self.assertQuerysetEqual(qs, []) + + +class TestInvalidValuesRelation(TestCase): + def test_invalid_values(self): + msg = "invalid literal for int() with base 10: 'abc'" + with self.assertRaisesMessage(ValueError, msg): + Annotation.objects.filter(tag='abc') + with self.assertRaisesMessage(ValueError, msg): + Annotation.objects.filter(tag__in=[123, 'abc']) + + +class TestTicket24605(TestCase): + def test_ticket_24605(self): + """ + Subquery table names should be quoted. 
+ """ + i1 = Individual.objects.create(alive=True) + RelatedIndividual.objects.create(related=i1) + i2 = Individual.objects.create(alive=False) + RelatedIndividual.objects.create(related=i2) + i3 = Individual.objects.create(alive=True) + i4 = Individual.objects.create(alive=False) + + self.assertSequenceEqual(Individual.objects.filter(Q(alive=False), Q(related_individual__isnull=True)), [i4]) + self.assertSequenceEqual( + Individual.objects.exclude(Q(alive=False), Q(related_individual__isnull=True)).order_by('pk'), + [i1, i2, i3] + ) + + +class Ticket23622Tests(TestCase): + @skipUnlessDBFeature('can_distinct_on_fields') + def test_ticket_23622(self): + """ + Make sure __pk__in and __in work the same for related fields when + using a distinct on subquery. + """ + a1 = Ticket23605A.objects.create() + a2 = Ticket23605A.objects.create() + c1 = Ticket23605C.objects.create(field_c0=0.0) + Ticket23605B.objects.create( + modela_fk=a1, field_b0=123, + field_b1=True, + modelc_fk=c1, + ) + Ticket23605B.objects.create( + modela_fk=a1, field_b0=23, + field_b1=True, + modelc_fk=c1, + ) + Ticket23605B.objects.create( + modela_fk=a1, field_b0=234, + field_b1=True, + modelc_fk=c1, + ) + Ticket23605B.objects.create( + modela_fk=a1, field_b0=12, + field_b1=True, + modelc_fk=c1, + ) + Ticket23605B.objects.create( + modela_fk=a2, field_b0=567, + field_b1=True, + modelc_fk=c1, + ) + Ticket23605B.objects.create( + modela_fk=a2, field_b0=76, + field_b1=True, + modelc_fk=c1, + ) + Ticket23605B.objects.create( + modela_fk=a2, field_b0=7, + field_b1=True, + modelc_fk=c1, + ) + Ticket23605B.objects.create( + modela_fk=a2, field_b0=56, + field_b1=True, + modelc_fk=c1, + ) + qx = ( + Q(ticket23605b__pk__in=Ticket23605B.objects.order_by('modela_fk', '-field_b1').distinct('modela_fk')) & + Q(ticket23605b__field_b0__gte=300) + ) + qy = ( + Q(ticket23605b__in=Ticket23605B.objects.order_by('modela_fk', '-field_b1').distinct('modela_fk')) & + Q(ticket23605b__field_b0__gte=300) + ) + self.assertEqual( 
+ set(Ticket23605A.objects.filter(qx).values_list('pk', flat=True)), + set(Ticket23605A.objects.filter(qy).values_list('pk', flat=True)) + ) + self.assertSequenceEqual(Ticket23605A.objects.filter(qx), [a2]) diff --git a/tests/raw_query/__init__.py b/tests/raw_query/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/raw_query/models.py b/tests/raw_query/models.py new file mode 100644 index 00000000..6e996322 --- /dev/null +++ b/tests/raw_query/models.py @@ -0,0 +1,40 @@ +from django.db import models + + +class Author(models.Model): + first_name = models.CharField(max_length=255) + last_name = models.CharField(max_length=255) + dob = models.DateField() + + def __init__(self, *args, **kwargs): + super(Author, self).__init__(*args, **kwargs) + # Protect against annotations being passed to __init__ -- + # this'll make the test suite get angry if annotations aren't + # treated differently than fields. + for k in kwargs: + assert k in [f.attname for f in self._meta.fields], \ + "Author.__init__ got an unexpected parameter: %s" % k + + +class Book(models.Model): + title = models.CharField(max_length=255) + author = models.ForeignKey(Author, models.CASCADE) + paperback = models.BooleanField(default=False) + opening_line = models.TextField() + + +class BookFkAsPk(models.Model): + book = models.ForeignKey(Book, models.CASCADE, primary_key=True, db_column="not_the_default") + + +class Coffee(models.Model): + brand = models.CharField(max_length=255, db_column="name") + price = models.DecimalField(max_digits=10, decimal_places=2, default=0) + + +class Reviewer(models.Model): + reviewed = models.ManyToManyField(Book) + + +class FriendlyAuthor(Author): + pass diff --git a/tests/raw_query/tests.py b/tests/raw_query/tests.py new file mode 100644 index 00000000..360dd31d --- /dev/null +++ b/tests/raw_query/tests.py @@ -0,0 +1,311 @@ +from __future__ import unicode_literals + +from datetime import date +from decimal import Decimal + +from 
django.db.models.query import RawQuerySet +from django.db.models.query_utils import InvalidQuery +from django.test import TestCase, skipUnlessDBFeature + +from .models import Author, Book, BookFkAsPk, Coffee, FriendlyAuthor, Reviewer + + +class RawQueryTests(TestCase): + + @classmethod + def setUpTestData(cls): + cls.a1 = Author.objects.create(first_name='Joe', last_name='Smith', dob=date(1950, 9, 20)) + cls.a2 = Author.objects.create(first_name='Jill', last_name='Doe', dob=date(1920, 4, 2)) + cls.a3 = Author.objects.create(first_name='Bob', last_name='Smith', dob=date(1986, 1, 25)) + cls.a4 = Author.objects.create(first_name='Bill', last_name='Jones', dob=date(1932, 5, 10)) + cls.b1 = Book.objects.create( + title='The awesome book', author=cls.a1, paperback=False, + opening_line='It was a bright cold day in April and the clocks were striking thirteen.', + ) + cls.b2 = Book.objects.create( + title='The horrible book', author=cls.a1, paperback=True, + opening_line=( + 'On an evening in the latter part of May a middle-aged man ' + 'was walking homeward from Shaston to the village of Marlott, ' + 'in the adjoining Vale of Blakemore, or Blackmoor.' 
+ ), + ) + cls.b3 = Book.objects.create( + title='Another awesome book', author=cls.a1, paperback=False, + opening_line='A squat grey building of only thirty-four stories.', + ) + cls.b4 = Book.objects.create( + title='Some other book', author=cls.a3, paperback=True, + opening_line='It was the day my grandmother exploded.', + ) + cls.c1 = Coffee.objects.create(brand='dunkin doughnuts') + cls.c2 = Coffee.objects.create(brand='starbucks') + cls.r1 = Reviewer.objects.create() + cls.r2 = Reviewer.objects.create() + cls.r1.reviewed.add(cls.b2, cls.b3, cls.b4) + + def assertSuccessfulRawQuery(self, model, query, expected_results, + expected_annotations=(), params=[], translations=None): + """ + Execute the passed query against the passed model and check the output + """ + results = list(model.objects.raw(query, params=params, translations=translations)) + self.assertProcessed(model, results, expected_results, expected_annotations) + self.assertAnnotations(results, expected_annotations) + + def assertProcessed(self, model, results, orig, expected_annotations=()): + """ + Compare the results of a raw query against expected results + """ + self.assertEqual(len(results), len(orig)) + for index, item in enumerate(results): + orig_item = orig[index] + for annotation in expected_annotations: + setattr(orig_item, *annotation) + + for field in model._meta.fields: + # All values on the model are equal + self.assertEqual( + getattr(item, field.attname), + getattr(orig_item, field.attname) + ) + # This includes checking that they are the same type + self.assertEqual( + type(getattr(item, field.attname)), + type(getattr(orig_item, field.attname)) + ) + + def assertNoAnnotations(self, results): + """ + The results of a raw query contain no annotations + """ + self.assertAnnotations(results, ()) + + def assertAnnotations(self, results, expected_annotations): + """ + The passed raw query results contain the expected annotations + """ + if expected_annotations: + for index, result in 
enumerate(results): + annotation, value = expected_annotations[index] + self.assertTrue(hasattr(result, annotation)) + self.assertEqual(getattr(result, annotation), value) + + def test_rawqueryset_repr(self): + queryset = RawQuerySet(raw_query='SELECT * FROM raw_query_author') + self.assertEqual(repr(queryset), '') + self.assertEqual(repr(queryset.query), '') + + def test_simple_raw_query(self): + """ + Basic test of raw query with a simple database query + """ + query = "SELECT * FROM raw_query_author" + authors = Author.objects.all() + self.assertSuccessfulRawQuery(Author, query, authors) + + def test_raw_query_lazy(self): + """ + Raw queries are lazy: they aren't actually executed until they're + iterated over. + """ + q = Author.objects.raw('SELECT * FROM raw_query_author') + self.assertIsNone(q.query.cursor) + list(q) + self.assertIsNotNone(q.query.cursor) + + def test_FK_raw_query(self): + """ + Test of a simple raw query against a model containing a foreign key + """ + query = "SELECT * FROM raw_query_book" + books = Book.objects.all() + self.assertSuccessfulRawQuery(Book, query, books) + + def test_db_column_handler(self): + """ + Test of a simple raw query against a model containing a field with + db_column defined. 
+ """ + query = "SELECT * FROM raw_query_coffee" + coffees = Coffee.objects.all() + self.assertSuccessfulRawQuery(Coffee, query, coffees) + + def test_order_handler(self): + """ + Test of raw raw query's tolerance for columns being returned in any + order + """ + selects = ( + ('dob, last_name, first_name, id'), + ('last_name, dob, first_name, id'), + ('first_name, last_name, dob, id'), + ) + + for select in selects: + query = "SELECT %s FROM raw_query_author" % select + authors = Author.objects.all() + self.assertSuccessfulRawQuery(Author, query, authors) + + def test_translations(self): + """ + Test of raw query's optional ability to translate unexpected result + column names to specific model fields + """ + query = "SELECT first_name AS first, last_name AS last, dob, id FROM raw_query_author" + translations = {'first': 'first_name', 'last': 'last_name'} + authors = Author.objects.all() + self.assertSuccessfulRawQuery(Author, query, authors, translations=translations) + + def test_params(self): + """ + Test passing optional query parameters + """ + query = "SELECT * FROM raw_query_author WHERE first_name = %s" + author = Author.objects.all()[2] + params = [author.first_name] + qset = Author.objects.raw(query, params=params) + results = list(qset) + self.assertProcessed(Author, results, [author]) + self.assertNoAnnotations(results) + self.assertEqual(len(results), 1) + self.assertIsInstance(repr(qset), str) + + @skipUnlessDBFeature('supports_paramstyle_pyformat') + def test_pyformat_params(self): + """ + Test passing optional query parameters + """ + query = "SELECT * FROM raw_query_author WHERE first_name = %(first)s" + author = Author.objects.all()[2] + params = {'first': author.first_name} + qset = Author.objects.raw(query, params=params) + results = list(qset) + self.assertProcessed(Author, results, [author]) + self.assertNoAnnotations(results) + self.assertEqual(len(results), 1) + self.assertIsInstance(repr(qset), str) + + def test_query_representation(self): 
+ """ + Test representation of raw query with parameters + """ + query = "SELECT * FROM raw_query_author WHERE last_name = %(last)s" + qset = Author.objects.raw(query, {'last': 'foo'}) + self.assertEqual(repr(qset), "") + self.assertEqual(repr(qset.query), "") + + query = "SELECT * FROM raw_query_author WHERE last_name = %s" + qset = Author.objects.raw(query, {'foo'}) + self.assertEqual(repr(qset), "") + self.assertEqual(repr(qset.query), "") + + def test_many_to_many(self): + """ + Test of a simple raw query against a model containing a m2m field + """ + query = "SELECT * FROM raw_query_reviewer" + reviewers = Reviewer.objects.all() + self.assertSuccessfulRawQuery(Reviewer, query, reviewers) + + def test_extra_conversions(self): + """ + Test to insure that extra translations are ignored. + """ + query = "SELECT * FROM raw_query_author" + translations = {'something': 'else'} + authors = Author.objects.all() + self.assertSuccessfulRawQuery(Author, query, authors, translations=translations) + + def test_missing_fields(self): + query = "SELECT id, first_name, dob FROM raw_query_author" + for author in Author.objects.raw(query): + self.assertIsNotNone(author.first_name) + # last_name isn't given, but it will be retrieved on demand + self.assertIsNotNone(author.last_name) + + def test_missing_fields_without_PK(self): + query = "SELECT first_name, dob FROM raw_query_author" + with self.assertRaisesMessage(InvalidQuery, 'Raw query must include the primary key'): + list(Author.objects.raw(query)) + + def test_annotations(self): + query = ( + "SELECT a.*, count(b.id) as book_count " + "FROM raw_query_author a " + "LEFT JOIN raw_query_book b ON a.id = b.author_id " + "GROUP BY a.id, a.first_name, a.last_name, a.dob ORDER BY a.id" + ) + expected_annotations = ( + ('book_count', 3), + ('book_count', 0), + ('book_count', 1), + ('book_count', 0), + ) + authors = Author.objects.all() + self.assertSuccessfulRawQuery(Author, query, authors, expected_annotations) + + def 
test_white_space_query(self): + query = " SELECT * FROM raw_query_author" + authors = Author.objects.all() + self.assertSuccessfulRawQuery(Author, query, authors) + + def test_multiple_iterations(self): + query = "SELECT * FROM raw_query_author" + normal_authors = Author.objects.all() + raw_authors = Author.objects.raw(query) + + # First Iteration + first_iterations = 0 + for index, raw_author in enumerate(raw_authors): + self.assertEqual(normal_authors[index], raw_author) + first_iterations += 1 + + # Second Iteration + second_iterations = 0 + for index, raw_author in enumerate(raw_authors): + self.assertEqual(normal_authors[index], raw_author) + second_iterations += 1 + + self.assertEqual(first_iterations, second_iterations) + + def test_get_item(self): + # Indexing on RawQuerySets + query = "SELECT * FROM raw_query_author ORDER BY id ASC" + third_author = Author.objects.raw(query)[2] + self.assertEqual(third_author.first_name, 'Bob') + + first_two = Author.objects.raw(query)[0:2] + self.assertEqual(len(first_two), 2) + + with self.assertRaises(TypeError): + Author.objects.raw(query)['test'] + + def test_inheritance(self): + f = FriendlyAuthor.objects.create(first_name="Wesley", last_name="Chun", dob=date(1962, 10, 28)) + query = "SELECT * FROM raw_query_friendlyauthor" + self.assertEqual( + [o.pk for o in FriendlyAuthor.objects.raw(query)], [f.pk] + ) + + def test_query_count(self): + self.assertNumQueries(1, list, Author.objects.raw("SELECT * FROM raw_query_author")) + + def test_subquery_in_raw_sql(self): + list(Book.objects.raw('SELECT id FROM (SELECT * FROM raw_query_book WHERE paperback IS NOT NULL) sq')) + + def test_db_column_name_is_used_in_raw_query(self): + """ + Regression test that ensures the `column` attribute on the field is + used to generate the list of fields included in the query, as opposed + to the `attname`. This is important when the primary key is a + ForeignKey field because `attname` and `column` are not necessarily the + same. 
+ """ + b = BookFkAsPk.objects.create(book=self.b1) + self.assertEqual(list(BookFkAsPk.objects.raw('SELECT not_the_default FROM raw_query_bookfkaspk')), [b]) + + def test_decimal_parameter(self): + c = Coffee.objects.create(brand='starbucks', price=20.5) + qs = Coffee.objects.raw("SELECT * FROM raw_query_coffee WHERE price >= %s", params=[Decimal(20)]) + self.assertEqual(list(qs), [c]) diff --git a/tests/reverse_lookup/__init__.py b/tests/reverse_lookup/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/reverse_lookup/models.py b/tests/reverse_lookup/models.py new file mode 100644 index 00000000..51e879bf --- /dev/null +++ b/tests/reverse_lookup/models.py @@ -0,0 +1,35 @@ +""" +Reverse lookups + +This demonstrates the reverse lookup features of the database API. +""" + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class User(models.Model): + name = models.CharField(max_length=200) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Poll(models.Model): + question = models.CharField(max_length=200) + creator = models.ForeignKey(User, models.CASCADE) + + def __str__(self): + return self.question + + +@python_2_unicode_compatible +class Choice(models.Model): + name = models.CharField(max_length=100) + poll = models.ForeignKey(Poll, models.CASCADE, related_name="poll_choice") + related_poll = models.ForeignKey(Poll, models.CASCADE, related_name="related_choice") + + def __str__(self): + return self.name diff --git a/tests/reverse_lookup/tests.py b/tests/reverse_lookup/tests.py new file mode 100644 index 00000000..dda3c296 --- /dev/null +++ b/tests/reverse_lookup/tests.py @@ -0,0 +1,52 @@ +from __future__ import unicode_literals + +from django.core.exceptions import FieldError +from django.test import TestCase + +from .models import Choice, Poll, User + + +class ReverseLookupTests(TestCase): + + def setUp(self): + john = 
User.objects.create(name="John Doe") + jim = User.objects.create(name="Jim Bo") + first_poll = Poll.objects.create( + question="What's the first question?", + creator=john + ) + second_poll = Poll.objects.create( + question="What's the second question?", + creator=jim + ) + Choice.objects.create( + poll=first_poll, + related_poll=second_poll, + name="This is the answer." + ) + + def test_reverse_by_field(self): + u1 = User.objects.get( + poll__question__exact="What's the first question?" + ) + self.assertEqual(u1.name, "John Doe") + + u2 = User.objects.get( + poll__question__exact="What's the second question?" + ) + self.assertEqual(u2.name, "Jim Bo") + + def test_reverse_by_related_name(self): + p1 = Poll.objects.get(poll_choice__name__exact="This is the answer.") + self.assertEqual(p1.question, "What's the first question?") + + p2 = Poll.objects.get( + related_choice__name__exact="This is the answer.") + self.assertEqual(p2.question, "What's the second question?") + + def test_reverse_field_name_disallowed(self): + """ + If a related_name is given you can't use the field name instead + """ + with self.assertRaises(FieldError): + Poll.objects.get(choice__name__exact="This is the answer") diff --git a/tests/runtests.py b/tests/runtests.py new file mode 100644 index 00000000..8840ee1d --- /dev/null +++ b/tests/runtests.py @@ -0,0 +1,484 @@ +#!/usr/bin/env python +import argparse +import atexit +import copy +import os +import shutil +import subprocess +import sys +import tempfile +import warnings + +import django +from django.apps import apps +from django.conf import settings +from django.db import connection, connections +from django.test import TestCase, TransactionTestCase +from django.test.runner import default_test_processes +from django.test.selenium import SeleniumTestCaseBase +from django.test.utils import get_runner +from django.utils.deprecation import RemovedInDjango30Warning +from django.utils.log import DEFAULT_LOGGING + +try: + import MySQLdb +except 
ImportError: + pass +else: + # Ignore informational warnings from QuerySet.explain(). + warnings.filterwarnings('ignore', r'\(1003, *', category=MySQLdb.Warning) + +# Make deprecation warnings errors to ensure no usage of deprecated features. +warnings.simplefilter("error", RemovedInDjango30Warning) +# Make runtime warning errors to ensure no usage of error prone patterns. +warnings.simplefilter("error", RuntimeWarning) +# Ignore known warnings in test dependencies. +warnings.filterwarnings("ignore", "'U' mode is deprecated", DeprecationWarning, module='docutils.io') + +RUNTESTS_DIR = os.path.abspath(os.path.dirname(__file__)) + +TEMPLATE_DIR = os.path.join(RUNTESTS_DIR, 'templates') + +# Create a specific subdirectory for the duration of the test suite. +TMPDIR = tempfile.mkdtemp(prefix='django_') +# Set the TMPDIR environment variable in addition to tempfile.tempdir +# so that children processes inherit it. +tempfile.tempdir = os.environ['TMPDIR'] = TMPDIR + +# Removing the temporary TMPDIR. +atexit.register(shutil.rmtree, TMPDIR) + + +SUBDIRS_TO_SKIP = [ + 'data', + 'import_error_package', + 'test_runner_apps', +] + +ALWAYS_INSTALLED_APPS = [ + 'django.contrib.contenttypes', + 'django.contrib.auth', + 'django.contrib.sites', + 'django.contrib.sessions', + 'django.contrib.messages', + 'django.contrib.admin.apps.SimpleAdminConfig', + 'django.contrib.staticfiles', +] + +ALWAYS_MIDDLEWARE = [ + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.middleware.common.CommonMiddleware', + 'django.middleware.csrf.CsrfViewMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django.contrib.messages.middleware.MessageMiddleware', +] + +# Need to add the associated contrib app to INSTALLED_APPS in some cases to +# avoid "RuntimeError: Model class X doesn't declare an explicit app_label +# and isn't in an application in INSTALLED_APPS." 
+CONTRIB_TESTS_TO_APPS = { + 'flatpages_tests': 'django.contrib.flatpages', + 'redirects_tests': 'django.contrib.redirects', +} + + +def get_test_modules(): + modules = [] + discovery_paths = [(None, RUNTESTS_DIR)] + if connection.features.gis_enabled: + # GIS tests are in nested apps + discovery_paths.append(('gis_tests', os.path.join(RUNTESTS_DIR, 'gis_tests'))) + else: + SUBDIRS_TO_SKIP.append('gis_tests') + + for modpath, dirpath in discovery_paths: + for f in os.listdir(dirpath): + if ('.' not in f and + os.path.basename(f) not in SUBDIRS_TO_SKIP and + not os.path.isfile(f) and + os.path.exists(os.path.join(dirpath, f, '__init__.py'))): + modules.append((modpath, f)) + return modules + + +def get_installed(): + return [app_config.name for app_config in apps.get_app_configs()] + + +def setup(verbosity, test_labels, parallel): + # Reduce the given test labels to just the app module path. + test_labels_set = set() + for label in test_labels: + bits = label.split('.')[:1] + test_labels_set.add('.'.join(bits)) + + if verbosity >= 1: + msg = "Testing against Django installed in '%s'" % os.path.dirname(django.__file__) + max_parallel = default_test_processes() if parallel == 0 else parallel + if max_parallel > 1: + msg += " with up to %d processes" % max_parallel + print(msg) + + # Force declaring available_apps in TransactionTestCase for faster tests. + def no_available_apps(self): + raise Exception("Please define available_apps in TransactionTestCase " + "and its subclasses.") + TransactionTestCase.available_apps = property(no_available_apps) + TestCase.available_apps = None + + state = { + 'INSTALLED_APPS': settings.INSTALLED_APPS, + 'ROOT_URLCONF': getattr(settings, "ROOT_URLCONF", ""), + 'TEMPLATES': settings.TEMPLATES, + 'LANGUAGE_CODE': settings.LANGUAGE_CODE, + 'STATIC_URL': settings.STATIC_URL, + 'STATIC_ROOT': settings.STATIC_ROOT, + 'MIDDLEWARE': settings.MIDDLEWARE, + } + + # Redirect some settings for the duration of these tests. 
+ settings.INSTALLED_APPS = ALWAYS_INSTALLED_APPS + settings.ROOT_URLCONF = 'urls' + settings.STATIC_URL = '/static/' + settings.STATIC_ROOT = os.path.join(TMPDIR, 'static') + settings.TEMPLATES = [{ + 'BACKEND': 'django.template.backends.django.DjangoTemplates', + 'DIRS': [TEMPLATE_DIR], + 'APP_DIRS': True, + 'OPTIONS': { + 'context_processors': [ + 'django.template.context_processors.debug', + 'django.template.context_processors.request', + 'django.contrib.auth.context_processors.auth', + 'django.contrib.messages.context_processors.messages', + ], + }, + }] + settings.LANGUAGE_CODE = 'en' + settings.SITE_ID = 1 + settings.MIDDLEWARE = ALWAYS_MIDDLEWARE + settings.MIGRATION_MODULES = { + # This lets us skip creating migrations for the test models as many of + # them depend on one of the following contrib applications. + 'auth': None, + 'contenttypes': None, + 'sessions': None, + } + log_config = copy.deepcopy(DEFAULT_LOGGING) + # Filter out non-error logging so we don't have to capture it in lots of + # tests. + log_config['loggers']['django']['level'] = 'ERROR' + settings.LOGGING = log_config + settings.SILENCED_SYSTEM_CHECKS = [ + 'fields.W342', # ForeignKey(unique=True) -> OneToOneField + ] + + # Load all the ALWAYS_INSTALLED_APPS. + django.setup() + + # It would be nice to put this validation earlier but it must come after + # django.setup() so that connection.features.gis_enabled can be accessed + # without raising AppRegistryNotReady when running gis_tests in isolation + # on some backends (e.g. PostGIS). + if 'gis_tests' in test_labels_set and not connection.features.gis_enabled: + print('Aborting: A GIS database backend is required to run gis_tests.') + sys.exit(1) + + # Load all the test model apps. + test_modules = get_test_modules() + + installed_app_names = set(get_installed()) + for modpath, module_name in test_modules: + if modpath: + module_label = modpath + '.' 
+ module_name + else: + module_label = module_name + # if the module (or an ancestor) was named on the command line, or + # no modules were named (i.e., run all), import + # this module and add it to INSTALLED_APPS. + module_found_in_labels = not test_labels or any( + # exact match or ancestor match + module_label == label or module_label.startswith(label + '.') + for label in test_labels_set + ) + + if module_name in CONTRIB_TESTS_TO_APPS and module_found_in_labels: + settings.INSTALLED_APPS.append(CONTRIB_TESTS_TO_APPS[module_name]) + + if module_found_in_labels and module_label not in installed_app_names: + if verbosity >= 2: + print("Importing application %s" % module_name) + settings.INSTALLED_APPS.append(module_label) + + # Add contrib.gis to INSTALLED_APPS if needed (rather than requiring + # @override_settings(INSTALLED_APPS=...) on all test cases. + gis = 'django.contrib.gis' + if connection.features.gis_enabled and gis not in settings.INSTALLED_APPS: + if verbosity >= 2: + print("Importing application %s" % gis) + settings.INSTALLED_APPS.append(gis) + + apps.set_installed_apps(settings.INSTALLED_APPS) + + return state + + +def teardown(state): + # Restore the old settings. + for key, value in state.items(): + setattr(settings, key, value) + # Discard the multiprocessing.util finalizer that tries to remove a + # temporary directory that's already removed by this script's + # atexit.register(shutil.rmtree, TMPDIR) handler. Prevents + # FileNotFoundError at the end of a test run on Python 3.6+ (#27890). + from multiprocessing.util import _finalizer_registry + _finalizer_registry.pop((-100, 0), None) + + +def actual_test_processes(parallel): + if parallel == 0: + # This doesn't work before django.setup() on some databases. 
+ if all(conn.features.can_clone_databases for conn in connections.all()): + return default_test_processes() + else: + return 1 + else: + return parallel + + +class ActionSelenium(argparse.Action): + """ + Validate the comma-separated list of requested browsers. + """ + def __call__(self, parser, namespace, values, option_string=None): + browsers = values.split(',') + for browser in browsers: + try: + SeleniumTestCaseBase.import_webdriver(browser) + except ImportError: + raise argparse.ArgumentError(self, "Selenium browser specification '%s' is not valid." % browser) + setattr(namespace, self.dest, browsers) + + +def django_tests(verbosity, interactive, failfast, keepdb, reverse, + test_labels, debug_sql, parallel, tags, exclude_tags): + state = setup(verbosity, test_labels, parallel) + extra_tests = [] + + # Run the test suite, including the extra validation tests. + if not hasattr(settings, 'TEST_RUNNER'): + settings.TEST_RUNNER = 'django.test.runner.DiscoverRunner' + TestRunner = get_runner(settings) + + test_runner = TestRunner( + verbosity=verbosity, + interactive=interactive, + failfast=failfast, + keepdb=keepdb, + reverse=reverse, + debug_sql=debug_sql, + parallel=actual_test_processes(parallel), + tags=tags, + exclude_tags=exclude_tags, + ) + failures = test_runner.run_tests( + test_labels or get_installed(), + extra_tests=extra_tests, + ) + teardown(state) + return failures + + +def get_subprocess_args(options): + subprocess_args = [ + sys.executable, __file__, '--settings=%s' % options.settings + ] + if options.failfast: + subprocess_args.append('--failfast') + if options.verbosity: + subprocess_args.append('--verbosity=%s' % options.verbosity) + if not options.interactive: + subprocess_args.append('--noinput') + if options.tags: + subprocess_args.append('--tag=%s' % options.tags) + if options.exclude_tags: + subprocess_args.append('--exclude_tag=%s' % options.exclude_tags) + return subprocess_args + + +def bisect_tests(bisection_label, options, 
test_labels, parallel): + state = setup(options.verbosity, test_labels, parallel) + + test_labels = test_labels or get_installed() + + print('***** Bisecting test suite: %s' % ' '.join(test_labels)) + + # Make sure the bisection point isn't in the test list + # Also remove tests that need to be run in specific combinations + for label in [bisection_label, 'model_inheritance_same_model_name']: + try: + test_labels.remove(label) + except ValueError: + pass + + subprocess_args = get_subprocess_args(options) + + iteration = 1 + while len(test_labels) > 1: + midpoint = len(test_labels) // 2 + test_labels_a = test_labels[:midpoint] + [bisection_label] + test_labels_b = test_labels[midpoint:] + [bisection_label] + print('***** Pass %da: Running the first half of the test suite' % iteration) + print('***** Test labels: %s' % ' '.join(test_labels_a)) + failures_a = subprocess.call(subprocess_args + test_labels_a) + + print('***** Pass %db: Running the second half of the test suite' % iteration) + print('***** Test labels: %s' % ' '.join(test_labels_b)) + print('') + failures_b = subprocess.call(subprocess_args + test_labels_b) + + if failures_a and not failures_b: + print("***** Problem found in first half. Bisecting again...") + iteration += 1 + test_labels = test_labels_a[:-1] + elif failures_b and not failures_a: + print("***** Problem found in second half. Bisecting again...") + iteration += 1 + test_labels = test_labels_b[:-1] + elif failures_a and failures_b: + print("***** Multiple sources of failure found") + break + else: + print("***** No source of failure found... 
try pair execution (--pair)") + break + + if len(test_labels) == 1: + print("***** Source of error: %s" % test_labels[0]) + teardown(state) + + +def paired_tests(paired_test, options, test_labels, parallel): + state = setup(options.verbosity, test_labels, parallel) + + test_labels = test_labels or get_installed() + + print('***** Trying paired execution') + + # Make sure the constant member of the pair isn't in the test list + # Also remove tests that need to be run in specific combinations + for label in [paired_test, 'model_inheritance_same_model_name']: + try: + test_labels.remove(label) + except ValueError: + pass + + subprocess_args = get_subprocess_args(options) + + for i, label in enumerate(test_labels): + print('***** %d of %d: Check test pairing with %s' % ( + i + 1, len(test_labels), label)) + failures = subprocess.call(subprocess_args + [label, paired_test]) + if failures: + print('***** Found problem pair with %s' % label) + return + + print('***** No problem pair found') + teardown(state) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run the Django test suite.") + parser.add_argument( + 'modules', nargs='*', metavar='module', + help='Optional path(s) to test modules; e.g. 
"i18n" or ' + '"i18n.tests.TranslationTests.test_lazy_objects".', + ) + parser.add_argument( + '-v', '--verbosity', default=1, type=int, choices=[0, 1, 2, 3], + help='Verbosity level; 0=minimal output, 1=normal output, 2=all output', + ) + parser.add_argument( + '--noinput', action='store_false', dest='interactive', + help='Tells Django to NOT prompt the user for input of any kind.', + ) + parser.add_argument( + '--failfast', action='store_true', dest='failfast', + help='Tells Django to stop running the test suite after first failed test.', + ) + parser.add_argument( + '-k', '--keepdb', action='store_true', dest='keepdb', + help='Tells Django to preserve the test database between runs.', + ) + parser.add_argument( + '--settings', + help='Python path to settings module, e.g. "myproject.settings". If ' + 'this isn\'t provided, either the DJANGO_SETTINGS_MODULE ' + 'environment variable or "test_sqlite" will be used.', + ) + parser.add_argument( + '--bisect', + help='Bisect the test suite to discover a test that causes a test ' + 'failure when combined with the named test.', + ) + parser.add_argument( + '--pair', + help='Run the test suite in pairs with the named test to find problem pairs.', + ) + parser.add_argument( + '--reverse', action='store_true', + help='Sort test suites and test cases in opposite order to debug ' + 'test side effects not apparent with normal execution lineup.', + ) + parser.add_argument( + '--selenium', dest='selenium', action=ActionSelenium, metavar='BROWSERS', + help='A comma-separated list of browsers to run the Selenium tests against.', + ) + parser.add_argument( + '--debug-sql', action='store_true', dest='debug_sql', + help='Turn on the SQL query logger within tests.', + ) + parser.add_argument( + '--parallel', dest='parallel', nargs='?', default=0, type=int, + const=default_test_processes(), metavar='N', + help='Run tests using up to N parallel processes.', + ) + parser.add_argument( + '--tag', dest='tags', action='append', + help='Run 
only tests with the specified tags. Can be used multiple times.', + ) + parser.add_argument( + '--exclude-tag', dest='exclude_tags', action='append', + help='Do not run tests with the specified tag. Can be used multiple times.', + ) + + options = parser.parse_args() + + # Allow including a trailing slash on app_labels for tab completion convenience + options.modules = [os.path.normpath(labels) for labels in options.modules] + + if options.settings: + os.environ['DJANGO_SETTINGS_MODULE'] = options.settings + else: + os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'test_mssql') + options.settings = os.environ['DJANGO_SETTINGS_MODULE'] + + if options.selenium: + if not options.tags: + options.tags = ['selenium'] + elif 'selenium' not in options.tags: + options.tags.append('selenium') + SeleniumTestCaseBase.browsers = options.selenium + + if options.bisect: + bisect_tests(options.bisect, options, options.modules, options.parallel) + elif options.pair: + paired_tests(options.pair, options, options.modules, options.parallel) + else: + failures = django_tests( + options.verbosity, options.interactive, options.failfast, + options.keepdb, options.reverse, options.modules, + options.debug_sql, options.parallel, options.tags, + options.exclude_tags, + ) + if failures: + sys.exit(1) diff --git a/tests/select_for_update/__init__.py b/tests/select_for_update/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/select_for_update/models.py b/tests/select_for_update/models.py new file mode 100644 index 00000000..b8154af3 --- /dev/null +++ b/tests/select_for_update/models.py @@ -0,0 +1,20 @@ +from django.db import models + + +class Country(models.Model): + name = models.CharField(max_length=30) + + +class City(models.Model): + name = models.CharField(max_length=30) + country = models.ForeignKey(Country, models.CASCADE) + + +class Person(models.Model): + name = models.CharField(max_length=30) + born = models.ForeignKey(City, models.CASCADE, related_name='+') + 
died = models.ForeignKey(City, models.CASCADE, related_name='+') + + +class PersonProfile(models.Model): + person = models.OneToOneField(Person, models.CASCADE, related_name='profile') diff --git a/tests/select_for_update/tests.py b/tests/select_for_update/tests.py new file mode 100644 index 00000000..3f68057b --- /dev/null +++ b/tests/select_for_update/tests.py @@ -0,0 +1,447 @@ +import threading +import time +from unittest import mock + +from multiple_database.routers import TestRouter + +from django.core.exceptions import FieldError +from django.db import ( + DatabaseError, NotSupportedError, connection, connections, router, + transaction, +) +from django.test import ( + TransactionTestCase, override_settings, skipIfDBFeature, + skipUnlessDBFeature, +) +from django.test.utils import CaptureQueriesContext + +from .models import City, Country, Person, PersonProfile + + +class SelectForUpdateTests(TransactionTestCase): + + available_apps = ['select_for_update'] + + def setUp(self): + # This is executed in autocommit mode so that code in + # run_select_for_update can see this data. + self.country1 = Country.objects.create(name='Belgium') + self.country2 = Country.objects.create(name='France') + self.city1 = City.objects.create(name='Liberchies', country=self.country1) + self.city2 = City.objects.create(name='Samois-sur-Seine', country=self.country2) + self.person = Person.objects.create(name='Reinhardt', born=self.city1, died=self.city2) + self.person_profile = PersonProfile.objects.create(person=self.person) + + # We need another database connection in transaction to test that one + # connection issuing a SELECT ... FOR UPDATE will block. + self.new_connection = connection.copy() + + def tearDown(self): + try: + self.end_blocking_transaction() + except (DatabaseError, AttributeError): + pass + self.new_connection.close() + + def start_blocking_transaction(self): + self.new_connection.set_autocommit(False) + # Start a blocking transaction. 
At some point, + # end_blocking_transaction() should be called. + self.cursor = self.new_connection.cursor() + sql = 'SELECT * FROM %(db_table)s %(for_update)s;' % { + 'db_table': Person._meta.db_table, + 'for_update': self.new_connection.ops.for_update_sql(), + } + self.cursor.execute(sql, ()) + self.cursor.fetchone() + + def end_blocking_transaction(self): + # Roll back the blocking transaction. + self.cursor.close() + self.new_connection.rollback() + self.new_connection.set_autocommit(True) + + def has_for_update_sql(self, queries, **kwargs): + # Examine the SQL that was executed to determine whether it + # contains the 'SELECT..FOR UPDATE' stanza. + for_update_sql = connection.ops.for_update_sql(**kwargs) + return any(for_update_sql in query['sql'] for query in queries) + + @skipUnlessDBFeature('has_select_for_update') + def test_for_update_sql_generated(self): + """ + The backend's FOR UPDATE variant appears in + generated SQL when select_for_update is invoked. + """ + with transaction.atomic(), CaptureQueriesContext(connection) as ctx: + list(Person.objects.all().select_for_update()) + self.assertTrue(self.has_for_update_sql(ctx.captured_queries)) + + @skipUnlessDBFeature('has_select_for_update_nowait') + def test_for_update_sql_generated_nowait(self): + """ + The backend's FOR UPDATE NOWAIT variant appears in + generated SQL when select_for_update is invoked. + """ + with transaction.atomic(), CaptureQueriesContext(connection) as ctx: + list(Person.objects.all().select_for_update(nowait=True)) + self.assertTrue(self.has_for_update_sql(ctx.captured_queries, nowait=True)) + + @skipUnlessDBFeature('has_select_for_update_skip_locked') + def test_for_update_sql_generated_skip_locked(self): + """ + The backend's FOR UPDATE SKIP LOCKED variant appears in + generated SQL when select_for_update is invoked. 
+ """ + with transaction.atomic(), CaptureQueriesContext(connection) as ctx: + list(Person.objects.all().select_for_update(skip_locked=True)) + self.assertTrue(self.has_for_update_sql(ctx.captured_queries, skip_locked=True)) + + @skipUnlessDBFeature('has_select_for_update_of') + def test_for_update_sql_generated_of(self): + """ + The backend's FOR UPDATE OF variant appears in the generated SQL when + select_for_update() is invoked. + """ + with transaction.atomic(), CaptureQueriesContext(connection) as ctx: + list(Person.objects.select_related( + 'born__country', + ).select_for_update( + of=('born__country',), + ).select_for_update( + of=('self', 'born__country') + )) + features = connections['default'].features + if features.select_for_update_of_column: + expected = ['"select_for_update_person"."id"', '"select_for_update_country"."id"'] + else: + expected = ['"select_for_update_person"', '"select_for_update_country"'] + if features.uppercases_column_names: + expected = [value.upper() for value in expected] + self.assertTrue(self.has_for_update_sql(ctx.captured_queries, of=expected)) + + @skipUnlessDBFeature('has_select_for_update_of') + def test_for_update_of_followed_by_values(self): + with transaction.atomic(): + values = list(Person.objects.select_for_update(of=('self',)).values('pk')) + self.assertEqual(values, [{'pk': self.person.pk}]) + + @skipUnlessDBFeature('has_select_for_update_of') + def test_for_update_of_followed_by_values_list(self): + with transaction.atomic(): + values = list(Person.objects.select_for_update(of=('self',)).values_list('pk')) + self.assertEqual(values, [(self.person.pk,)]) + + @skipUnlessDBFeature('has_select_for_update_of') + def test_for_update_of_self_when_self_is_not_selected(self): + """ + select_for_update(of=['self']) when the only columns selected are from + related tables. 
+ """ + with transaction.atomic(): + values = list(Person.objects.select_related('born').select_for_update(of=('self',)).values('born__name')) + self.assertEqual(values, [{'born__name': self.city1.name}]) + + @skipUnlessDBFeature('has_select_for_update_nowait') + def test_nowait_raises_error_on_block(self): + """ + If nowait is specified, we expect an error to be raised rather + than blocking. + """ + self.start_blocking_transaction() + status = [] + + thread = threading.Thread( + target=self.run_select_for_update, + args=(status,), + kwargs={'nowait': True}, + ) + + thread.start() + time.sleep(1) + thread.join() + self.end_blocking_transaction() + self.assertIsInstance(status[-1], DatabaseError) + + @skipUnlessDBFeature('has_select_for_update_skip_locked') + def test_skip_locked_skips_locked_rows(self): + """ + If skip_locked is specified, the locked row is skipped resulting in + Person.DoesNotExist. + """ + self.start_blocking_transaction() + status = [] + thread = threading.Thread( + target=self.run_select_for_update, + args=(status,), + kwargs={'skip_locked': True}, + ) + thread.start() + time.sleep(1) + thread.join() + self.end_blocking_transaction() + self.assertIsInstance(status[-1], Person.DoesNotExist) + + @skipIfDBFeature('has_select_for_update_nowait') + @skipUnlessDBFeature('has_select_for_update') + def test_unsupported_nowait_raises_error(self): + """ + NotSupportedError is raised if a SELECT...FOR UPDATE NOWAIT is run on + a database backend that supports FOR UPDATE but not NOWAIT. 
+ """ + with self.assertRaisesMessage(NotSupportedError, 'NOWAIT is not supported on this database backend.'): + with transaction.atomic(): + Person.objects.select_for_update(nowait=True).get() + + @skipIfDBFeature('has_select_for_update_skip_locked') + @skipUnlessDBFeature('has_select_for_update') + def test_unsupported_skip_locked_raises_error(self): + """ + NotSupportedError is raised if a SELECT...FOR UPDATE SKIP LOCKED is run + on a database backend that supports FOR UPDATE but not SKIP LOCKED. + """ + with self.assertRaisesMessage(NotSupportedError, 'SKIP LOCKED is not supported on this database backend.'): + with transaction.atomic(): + Person.objects.select_for_update(skip_locked=True).get() + + @skipIfDBFeature('has_select_for_update_of') + @skipUnlessDBFeature('has_select_for_update') + def test_unsupported_of_raises_error(self): + """ + NotSupportedError is raised if a SELECT...FOR UPDATE OF... is run on + a database backend that supports FOR UPDATE but not OF. + """ + msg = 'FOR UPDATE OF is not supported on this database backend.' + with self.assertRaisesMessage(NotSupportedError, msg): + with transaction.atomic(): + Person.objects.select_for_update(of=('self',)).get() + + @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of') + def test_unrelated_of_argument_raises_error(self): + """ + FieldError is raised if a non-relation field is specified in of=(...). + """ + msg = ( + 'Invalid field name(s) given in select_for_update(of=(...)): %s. ' + 'Only relational fields followed in the query are allowed. ' + 'Choices are: self, born, born__country.' 
+ ) + invalid_of = [ + ('nonexistent',), + ('name',), + ('born__nonexistent',), + ('born__name',), + ('born__nonexistent', 'born__name'), + ] + for of in invalid_of: + with self.subTest(of=of): + with self.assertRaisesMessage(FieldError, msg % ', '.join(of)): + with transaction.atomic(): + Person.objects.select_related('born__country').select_for_update(of=of).get() + + @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of') + def test_related_but_unselected_of_argument_raises_error(self): + """ + FieldError is raised if a relation field that is not followed in the + query is specified in of=(...). + """ + msg = ( + 'Invalid field name(s) given in select_for_update(of=(...)): %s. ' + 'Only relational fields followed in the query are allowed. ' + 'Choices are: self, born, profile.' + ) + for name in ['born__country', 'died', 'died__country']: + with self.subTest(name=name): + with self.assertRaisesMessage(FieldError, msg % name): + with transaction.atomic(): + Person.objects.select_related( + 'born', 'profile', + ).exclude(profile=None).select_for_update(of=(name,)).get() + + @skipUnlessDBFeature('has_select_for_update', 'has_select_for_update_of') + def test_reverse_one_to_one_of_arguments(self): + """ + Reverse OneToOneFields may be included in of=(...) as long as NULLs + are excluded because LEFT JOIN isn't allowed in SELECT FOR UPDATE. 
+ """ + with transaction.atomic(): + person = Person.objects.select_related( + 'profile', + ).exclude(profile=None).select_for_update(of=('profile',)).get() + self.assertEqual(person.profile, self.person_profile) + + #@skipUnlessDBFeature('has_select_for_update') + #def test_for_update_after_from(self): + # features_class = connections['default'].features.__class__ + # attribute_to_patch = "%s.%s.for_update_after_from" % (features_class.__module__, features_class.__name__) + # with mock.patch(attribute_to_patch, return_value=True): + # with transaction.atomic(): + # self.assertIn('FOR UPDATE WHERE', str(Person.objects.filter(name='foo').select_for_update().query)) + + @skipUnlessDBFeature('has_select_for_update') + def test_for_update_requires_transaction(self): + """ + A TransactionManagementError is raised + when a select_for_update query is executed outside of a transaction. + """ + msg = 'select_for_update cannot be used outside of a transaction.' + with self.assertRaisesMessage(transaction.TransactionManagementError, msg): + list(Person.objects.all().select_for_update()) + + @skipUnlessDBFeature('has_select_for_update') + def test_for_update_requires_transaction_only_in_execution(self): + """ + No TransactionManagementError is raised + when select_for_update is invoked outside of a transaction - + only when the query is executed. + """ + people = Person.objects.all().select_for_update() + msg = 'select_for_update cannot be used outside of a transaction.' 
+ with self.assertRaisesMessage(transaction.TransactionManagementError, msg): + list(people) + + @skipUnlessDBFeature('supports_select_for_update_with_limit') + def test_select_for_update_with_limit(self): + other = Person.objects.create(name='Grappeli', born=self.city1, died=self.city2) + with transaction.atomic(): + qs = list(Person.objects.all().order_by('pk').select_for_update()[1:2]) + self.assertEqual(qs[0], other) + + @skipIfDBFeature('supports_select_for_update_with_limit') + def test_unsupported_select_for_update_with_limit(self): + msg = 'LIMIT/OFFSET is not supported with select_for_update on this database backend.' + with self.assertRaisesMessage(NotSupportedError, msg): + with transaction.atomic(): + list(Person.objects.all().order_by('pk').select_for_update()[1:2]) + + def run_select_for_update(self, status, **kwargs): + """ + Utility method that runs a SELECT FOR UPDATE against all + Person instances. After the select_for_update, it attempts + to update the name of the only record, save, and commit. + + This function expects to run in a separate thread. + """ + status.append('started') + try: + # We need to enter transaction management again, as this is done on + # per-thread basis + with transaction.atomic(): + person = Person.objects.select_for_update(**kwargs).get() + person.name = 'Fred' + person.save() + except (DatabaseError, Person.DoesNotExist) as e: + status.append(e) + finally: + # This method is run in a separate thread. It uses its own + # database connection. Close it without waiting for the GC. + connection.close() + + @skipUnlessDBFeature('has_select_for_update') + @skipUnlessDBFeature('supports_transactions') + def test_block(self): + """ + A thread running a select_for_update that accesses rows being touched + by a similar operation on another connection blocks correctly. + """ + # First, let's start the transaction in our thread. 
 + self.start_blocking_transaction() + + # Now, try it again using the ORM's select_for_update + # facility. Do this in a separate thread. + status = [] + thread = threading.Thread( + target=self.run_select_for_update, args=(status,) + ) + + # The thread should immediately block, but we'll sleep + # for a bit to make sure. + thread.start() + sanity_count = 0 + while len(status) != 1 and sanity_count < 10: + sanity_count += 1 + time.sleep(1) + if sanity_count >= 10: + raise ValueError('Thread did not run and block') + + # Check the person hasn't been updated. Since this isn't + # using FOR UPDATE, it won't block. + p = Person.objects.get(pk=self.person.pk) + self.assertEqual('Reinhardt', p.name) + + # When we end our blocking transaction, our thread should + # be able to continue. + self.end_blocking_transaction() + thread.join(5.0) + + # Check the thread has finished. Assuming it has, we should + # find that it has updated the person's name. + self.assertFalse(thread.is_alive()) + + # We must commit the transaction to ensure that MySQL gets a fresh read, + # since by default it runs in REPEATABLE READ mode + transaction.commit() + + p = Person.objects.get(pk=self.person.pk) + self.assertEqual('Fred', p.name) + + @skipUnlessDBFeature('has_select_for_update') + def test_raw_lock_not_available(self): + """ + Running a raw query which can't obtain a FOR UPDATE lock raises + the correct exception + """ + self.start_blocking_transaction() + + def raw(status): + try: + list( + Person.objects.raw( + 'SELECT * FROM %s %s' % ( + Person._meta.db_table, + connection.ops.for_update_sql(nowait=True) + ) + ) + ) + except DatabaseError as e: + status.append(e) + finally: + # This method is run in a separate thread. It uses its own + # database connection. Close it without waiting for the GC. + # Connection cannot be closed on Oracle because cursor is still + # open. 
+ if connection.vendor != 'oracle': + connection.close() + + status = [] + thread = threading.Thread(target=raw, kwargs={'status': status}) + thread.start() + time.sleep(1) + thread.join() + self.end_blocking_transaction() + self.assertIsInstance(status[-1], DatabaseError) + + @skipUnlessDBFeature('has_select_for_update') + @override_settings(DATABASE_ROUTERS=[TestRouter()]) + def test_select_for_update_on_multidb(self): + query = Person.objects.select_for_update() + self.assertEqual(router.db_for_write(Person), query.db) + + @skipUnlessDBFeature('has_select_for_update') + def test_select_for_update_with_get(self): + with transaction.atomic(): + person = Person.objects.select_for_update().get(name='Reinhardt') + self.assertEqual(person.name, 'Reinhardt') + + def test_nowait_and_skip_locked(self): + with self.assertRaisesMessage(ValueError, 'The nowait option cannot be used with skip_locked.'): + Person.objects.select_for_update(nowait=True, skip_locked=True) + + def test_ordered_select_for_update(self): + """ + Subqueries should respect ordering as an ORDER BY clause may be useful + to specify a row locking order to prevent deadlocks (#27193). + """ + with transaction.atomic(): + qs = Person.objects.filter(id__in=Person.objects.order_by('-id').select_for_update()) + self.assertIn('ORDER BY', str(qs.query)) diff --git a/tests/select_related/__init__.py b/tests/select_related/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/select_related/models.py b/tests/select_related/models.py new file mode 100644 index 00000000..26bf34eb --- /dev/null +++ b/tests/select_related/models.py @@ -0,0 +1,138 @@ +""" +Tests for select_related() + +``select_related()`` follows all relationships and pre-caches any foreign key +values so that complex trees can be fetched in a single query. However, this +isn't always a good idea, so the ``depth`` argument control how many "levels" +the select-related behavior will traverse. 
+""" + +from django.contrib.contenttypes.fields import ( + GenericForeignKey, GenericRelation, +) +from django.contrib.contenttypes.models import ContentType +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + +# Who remembers high school biology? + + +@python_2_unicode_compatible +class Domain(models.Model): + name = models.CharField(max_length=50) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Kingdom(models.Model): + name = models.CharField(max_length=50) + domain = models.ForeignKey(Domain, models.CASCADE) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Phylum(models.Model): + name = models.CharField(max_length=50) + kingdom = models.ForeignKey(Kingdom, models.CASCADE) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Klass(models.Model): + name = models.CharField(max_length=50) + phylum = models.ForeignKey(Phylum, models.CASCADE) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Order(models.Model): + name = models.CharField(max_length=50) + klass = models.ForeignKey(Klass, models.CASCADE) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Family(models.Model): + name = models.CharField(max_length=50) + order = models.ForeignKey(Order, models.CASCADE) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Genus(models.Model): + name = models.CharField(max_length=50) + family = models.ForeignKey(Family, models.CASCADE) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Species(models.Model): + name = models.CharField(max_length=50) + genus = models.ForeignKey(Genus, models.CASCADE) + + def __str__(self): + return self.name + +# and we'll invent a new thing so we have a model with two foreign keys + + +@python_2_unicode_compatible +class HybridSpecies(models.Model): + name = 
models.CharField(max_length=50) + parent_1 = models.ForeignKey(Species, models.CASCADE, related_name='child_1') + parent_2 = models.ForeignKey(Species, models.CASCADE, related_name='child_2') + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Topping(models.Model): + name = models.CharField(max_length=30) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class Pizza(models.Model): + name = models.CharField(max_length=100) + toppings = models.ManyToManyField(Topping) + + def __str__(self): + return self.name + + +@python_2_unicode_compatible +class TaggedItem(models.Model): + tag = models.CharField(max_length=30) + + content_type = models.ForeignKey(ContentType, models.CASCADE, related_name='select_related_tagged_items') + object_id = models.PositiveIntegerField() + content_object = GenericForeignKey('content_type', 'object_id') + + def __str__(self): + return self.tag + + +@python_2_unicode_compatible +class Bookmark(models.Model): + url = models.URLField() + tags = GenericRelation(TaggedItem) + + def __str__(self): + return self.url diff --git a/tests/select_related/tests.py b/tests/select_related/tests.py new file mode 100644 index 00000000..9700d10d --- /dev/null +++ b/tests/select_related/tests.py @@ -0,0 +1,223 @@ +from __future__ import unicode_literals + +from django.core.exceptions import FieldError +from django.test import SimpleTestCase, TestCase + +from .models import ( + Bookmark, Domain, Family, Genus, HybridSpecies, Kingdom, Klass, Order, + Phylum, Pizza, Species, TaggedItem, +) + + +class SelectRelatedTests(TestCase): + + @classmethod + def create_tree(cls, stringtree): + """ + Helper to create a complete tree. 
+ """ + names = stringtree.split() + models = [Domain, Kingdom, Phylum, Klass, Order, Family, Genus, Species] + assert len(names) == len(models), (names, models) + + parent = None + for name, model in zip(names, models): + try: + obj = model.objects.get(name=name) + except model.DoesNotExist: + obj = model(name=name) + if parent: + setattr(obj, parent.__class__.__name__.lower(), parent) + obj.save() + parent = obj + + @classmethod + def setUpTestData(cls): + cls.create_tree("Eukaryota Animalia Anthropoda Insecta Diptera Drosophilidae Drosophila melanogaster") + cls.create_tree("Eukaryota Animalia Chordata Mammalia Primates Hominidae Homo sapiens") + cls.create_tree("Eukaryota Plantae Magnoliophyta Magnoliopsida Fabales Fabaceae Pisum sativum") + cls.create_tree("Eukaryota Fungi Basidiomycota Homobasidiomycatae Agaricales Amanitacae Amanita muscaria") + + def test_access_fks_without_select_related(self): + """ + Normally, accessing FKs doesn't fill in related objects + """ + with self.assertNumQueries(8): + fly = Species.objects.get(name="melanogaster") + domain = fly.genus.family.order.klass.phylum.kingdom.domain + self.assertEqual(domain.name, 'Eukaryota') + + def test_access_fks_with_select_related(self): + """ + A select_related() call will fill in those related objects without any + extra queries + """ + with self.assertNumQueries(1): + person = ( + Species.objects + .select_related('genus__family__order__klass__phylum__kingdom__domain') + .get(name="sapiens") + ) + domain = person.genus.family.order.klass.phylum.kingdom.domain + self.assertEqual(domain.name, 'Eukaryota') + + def test_list_without_select_related(self): + """ + select_related() also of course applies to entire lists, not just + items. This test verifies the expected behavior without select_related. 
+ """ + with self.assertNumQueries(9): + world = Species.objects.all() + families = [o.genus.family.name for o in world] + self.assertEqual(sorted(families), [ + 'Amanitacae', + 'Drosophilidae', + 'Fabaceae', + 'Hominidae', + ]) + + def test_list_with_select_related(self): + """ + select_related() also of course applies to entire lists, not just + items. This test verifies the expected behavior with select_related. + """ + with self.assertNumQueries(1): + world = Species.objects.all().select_related() + families = [o.genus.family.name for o in world] + self.assertEqual(sorted(families), [ + 'Amanitacae', + 'Drosophilidae', + 'Fabaceae', + 'Hominidae', + ]) + + def test_list_with_depth(self): + """ + Passing a relationship field lookup specifier to select_related() will + stop the descent at a particular level. This can be used on lists as + well. + """ + with self.assertNumQueries(5): + world = Species.objects.all().select_related('genus__family') + orders = [o.genus.family.order.name for o in world] + self.assertEqual(sorted(orders), ['Agaricales', 'Diptera', 'Fabales', 'Primates']) + + def test_select_related_with_extra(self): + s = (Species.objects.all() + .select_related() + .extra(select={'a': 'select_related_species.id + 10'})[0]) + self.assertEqual(s.id + 10, s.a) + + def test_certain_fields(self): + """ + The optional fields passed to select_related() control which related + models we pull in. This allows for smaller queries. + + In this case, we explicitly say to select the 'genus' and + 'genus.family' models, leading to the same number of queries as before. 
+ """ + with self.assertNumQueries(1): + world = Species.objects.select_related('genus__family') + families = [o.genus.family.name for o in world] + self.assertEqual(sorted(families), ['Amanitacae', 'Drosophilidae', 'Fabaceae', 'Hominidae']) + + def test_more_certain_fields(self): + """ + In this case, we explicitly say to select the 'genus' and + 'genus.family' models, leading to the same number of queries as before. + """ + with self.assertNumQueries(2): + world = Species.objects.filter(genus__name='Amanita')\ + .select_related('genus__family') + orders = [o.genus.family.order.name for o in world] + self.assertEqual(orders, ['Agaricales']) + + def test_field_traversal(self): + with self.assertNumQueries(1): + s = (Species.objects.all() + .select_related('genus__family__order') + .order_by('id')[0:1].get().genus.family.order.name) + self.assertEqual(s, 'Diptera') + + def test_depth_fields_fails(self): + with self.assertRaises(TypeError): + Species.objects.select_related('genus__family__order', depth=4) + + def test_none_clears_list(self): + queryset = Species.objects.select_related('genus').select_related(None) + self.assertIs(queryset.query.select_related, False) + + def test_chaining(self): + parent_1, parent_2 = Species.objects.all()[:2] + HybridSpecies.objects.create(name='hybrid', parent_1=parent_1, parent_2=parent_2) + queryset = HybridSpecies.objects.select_related('parent_1').select_related('parent_2') + with self.assertNumQueries(1): + obj = queryset[0] + self.assertEqual(obj.parent_1, parent_1) + self.assertEqual(obj.parent_2, parent_2) + + def test_select_related_after_values(self): + """ + Running select_related() after calling values() raises a TypeError + """ + message = "Cannot call select_related() after .values() or .values_list()" + with self.assertRaisesMessage(TypeError, message): + list(Species.objects.values('name').select_related('genus')) + + def test_select_related_after_values_list(self): + """ + Running select_related() after calling 
values_list() raises a TypeError + """ + message = "Cannot call select_related() after .values() or .values_list()" + with self.assertRaisesMessage(TypeError, message): + list(Species.objects.values_list('name').select_related('genus')) + + +class SelectRelatedValidationTests(SimpleTestCase): + """ + select_related() should thrown an error on fields that do not exist and + non-relational fields. + """ + non_relational_error = "Non-relational field given in select_related: '%s'. Choices are: %s" + invalid_error = "Invalid field name(s) given in select_related: '%s'. Choices are: %s" + + def test_non_relational_field(self): + with self.assertRaisesMessage(FieldError, self.non_relational_error % ('name', 'genus')): + list(Species.objects.select_related('name__some_field')) + + with self.assertRaisesMessage(FieldError, self.non_relational_error % ('name', 'genus')): + list(Species.objects.select_related('name')) + + with self.assertRaisesMessage(FieldError, self.non_relational_error % ('name', '(none)')): + list(Domain.objects.select_related('name')) + + def test_non_relational_field_nested(self): + # TODO: fix + return + with self.assertRaisesMessage(FieldError, self.non_relational_error % ('name', 'family')): + list(Species.objects.select_related('genus__name')) + + def test_many_to_many_field(self): + with self.assertRaisesMessage(FieldError, self.invalid_error % ('toppings', '(none)')): + list(Pizza.objects.select_related('toppings')) + + def test_reverse_relational_field(self): + with self.assertRaisesMessage(FieldError, self.invalid_error % ('child_1', 'genus')): + list(Species.objects.select_related('child_1')) + + def test_invalid_field(self): + with self.assertRaisesMessage(FieldError, self.invalid_error % ('invalid_field', 'genus')): + list(Species.objects.select_related('invalid_field')) + + with self.assertRaisesMessage(FieldError, self.invalid_error % ('related_invalid_field', 'family')): + 
list(Species.objects.select_related('genus__related_invalid_field')) + + with self.assertRaisesMessage(FieldError, self.invalid_error % ('invalid_field', '(none)')): + list(Domain.objects.select_related('invalid_field')) + + def test_generic_relations(self): + with self.assertRaisesMessage(FieldError, self.invalid_error % ('tags', '')): + list(Bookmark.objects.select_related('tags')) + + with self.assertRaisesMessage(FieldError, self.invalid_error % ('content_object', 'content_type')): + list(TaggedItem.objects.select_related('content_object')) diff --git a/tests/select_related_onetoone/__init__.py b/tests/select_related_onetoone/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/select_related_onetoone/models.py b/tests/select_related_onetoone/models.py new file mode 100644 index 00000000..0a20cc73 --- /dev/null +++ b/tests/select_related_onetoone/models.py @@ -0,0 +1,103 @@ +from django.db import models + + +class User(models.Model): + username = models.CharField(max_length=100) + email = models.EmailField() + + def __str__(self): + return self.username + + +class UserProfile(models.Model): + user = models.OneToOneField(User, models.CASCADE) + city = models.CharField(max_length=100) + state = models.CharField(max_length=2) + + def __str__(self): + return "%s, %s" % (self.city, self.state) + + +class UserStatResult(models.Model): + results = models.CharField(max_length=50) + + def __str__(self): + return 'UserStatResults, results = %s' % (self.results,) + + +class UserStat(models.Model): + user = models.OneToOneField(User, models.CASCADE, primary_key=True) + posts = models.IntegerField() + results = models.ForeignKey(UserStatResult, models.CASCADE) + + def __str__(self): + return 'UserStat, posts = %s' % (self.posts,) + + +class StatDetails(models.Model): + base_stats = models.OneToOneField(UserStat, models.CASCADE) + comments = models.IntegerField() + + def __str__(self): + return 'StatDetails, comments = %s' % (self.comments,) + + 
+class AdvancedUserStat(UserStat): + karma = models.IntegerField() + + +class Image(models.Model): + name = models.CharField(max_length=100) + + +class Product(models.Model): + name = models.CharField(max_length=100) + image = models.OneToOneField(Image, models.SET_NULL, null=True) + + +class Parent1(models.Model): + name1 = models.CharField(max_length=50) + + def __str__(self): + return self.name1 + + +class Parent2(models.Model): + # Avoid having two "id" fields in the Child1 subclass + id2 = models.AutoField(primary_key=True) + name2 = models.CharField(max_length=50) + + def __str__(self): + return self.name2 + + +class Child1(Parent1, Parent2): + value = models.IntegerField() + + def __str__(self): + return self.name1 + + +class Child2(Parent1): + parent2 = models.OneToOneField(Parent2, models.CASCADE) + value = models.IntegerField() + + def __str__(self): + return self.name1 + + +class Child3(Child2): + value3 = models.IntegerField() + + +class Child4(Child1): + value4 = models.IntegerField() + + +class LinkedList(models.Model): + name = models.CharField(max_length=50) + previous_item = models.OneToOneField( + 'self', models.CASCADE, + related_name='next_item', + blank=True, null=True, + ) diff --git a/tests/select_related_onetoone/tests.py b/tests/select_related_onetoone/tests.py new file mode 100644 index 00000000..0438257a --- /dev/null +++ b/tests/select_related_onetoone/tests.py @@ -0,0 +1,236 @@ +from django.core.exceptions import FieldError +from django.db.models import FilteredRelation +from django.test import SimpleTestCase, TestCase + +from .models import ( + AdvancedUserStat, Child1, Child2, Child3, Child4, Image, LinkedList, + Parent1, Parent2, Product, StatDetails, User, UserProfile, UserStat, + UserStatResult, +) + + +class ReverseSelectRelatedTestCase(TestCase): + def setUp(self): + user = User.objects.create(username="test") + UserProfile.objects.create(user=user, state="KS", city="Lawrence") + results = 
UserStatResult.objects.create(results='first results') + userstat = UserStat.objects.create(user=user, posts=150, results=results) + StatDetails.objects.create(base_stats=userstat, comments=259) + + user2 = User.objects.create(username="bob") + results2 = UserStatResult.objects.create(results='moar results') + advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5, results=results2) + StatDetails.objects.create(base_stats=advstat, comments=250) + p1 = Parent1(name1="Only Parent1") + p1.save() + c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2", value=1) + c1.save() + p2 = Parent2(name2="Child2 Parent2") + p2.save() + c2 = Child2(name1="Child2 Parent1", parent2=p2, value=2) + c2.save() + + def test_basic(self): + with self.assertNumQueries(1): + u = User.objects.select_related("userprofile").get(username="test") + self.assertEqual(u.userprofile.state, "KS") + + def test_follow_next_level(self): + with self.assertNumQueries(1): + u = User.objects.select_related("userstat__results").get(username="test") + self.assertEqual(u.userstat.posts, 150) + self.assertEqual(u.userstat.results.results, 'first results') + + def test_follow_two(self): + with self.assertNumQueries(1): + u = User.objects.select_related("userprofile", "userstat").get(username="test") + self.assertEqual(u.userprofile.state, "KS") + self.assertEqual(u.userstat.posts, 150) + + def test_follow_two_next_level(self): + with self.assertNumQueries(1): + u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test") + self.assertEqual(u.userstat.results.results, 'first results') + self.assertEqual(u.userstat.statdetails.comments, 259) + + def test_forward_and_back(self): + with self.assertNumQueries(1): + stat = UserStat.objects.select_related("user__userprofile").get(user__username="test") + self.assertEqual(stat.user.userprofile.state, 'KS') + self.assertEqual(stat.user.userstat.posts, 150) + + def test_back_and_forward(self): + with 
self.assertNumQueries(1): + u = User.objects.select_related("userstat").get(username="test") + self.assertEqual(u.userstat.user.username, 'test') + + def test_not_followed_by_default(self): + with self.assertNumQueries(2): + u = User.objects.select_related().get(username="test") + self.assertEqual(u.userstat.posts, 150) + + def test_follow_from_child_class(self): + with self.assertNumQueries(1): + stat = AdvancedUserStat.objects.select_related('user', 'statdetails').get(posts=200) + self.assertEqual(stat.statdetails.comments, 250) + self.assertEqual(stat.user.username, 'bob') + + def test_follow_inheritance(self): + with self.assertNumQueries(1): + stat = UserStat.objects.select_related('user', 'advanceduserstat').get(posts=200) + self.assertEqual(stat.advanceduserstat.posts, 200) + self.assertEqual(stat.user.username, 'bob') + with self.assertNumQueries(0): + self.assertEqual(stat.advanceduserstat.user.username, 'bob') + + def test_nullable_relation(self): + im = Image.objects.create(name="imag1") + p1 = Product.objects.create(name="Django Plushie", image=im) + p2 = Product.objects.create(name="Talking Django Plushie") + + with self.assertNumQueries(1): + result = sorted(Product.objects.select_related("image"), key=lambda x: x.name) + self.assertEqual([p.name for p in result], ["Django Plushie", "Talking Django Plushie"]) + + self.assertEqual(p1.image, im) + # Check for ticket #13839 + self.assertIsNone(p2.image) + + def test_missing_reverse(self): + """ + Ticket #13839: select_related() should NOT cache None + for missing objects on a reverse 1-1 relation. + """ + with self.assertNumQueries(1): + user = User.objects.select_related('userprofile').get(username='bob') + with self.assertRaises(UserProfile.DoesNotExist): + user.userprofile + + def test_nullable_missing_reverse(self): + """ + Ticket #13839: select_related() should NOT cache None + for missing objects on a reverse 0-1 relation. 
+ """ + Image.objects.create(name="imag1") + + with self.assertNumQueries(1): + image = Image.objects.select_related('product').get() + with self.assertRaises(Product.DoesNotExist): + image.product + + def test_parent_only(self): + with self.assertNumQueries(1): + p = Parent1.objects.select_related('child1').get(name1="Only Parent1") + with self.assertNumQueries(0): + with self.assertRaises(Child1.DoesNotExist): + p.child1 + + def test_multiple_subclass(self): + with self.assertNumQueries(1): + p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1") + self.assertEqual(p.child1.name2, 'Child1 Parent2') + + def test_onetoone_with_subclass(self): + with self.assertNumQueries(1): + p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2") + self.assertEqual(p.child2.name1, 'Child2 Parent1') + + def test_onetoone_with_two_subclasses(self): + with self.assertNumQueries(1): + p = Parent2.objects.select_related('child2', "child2__child3").get(name2="Child2 Parent2") + self.assertEqual(p.child2.name1, 'Child2 Parent1') + with self.assertRaises(Child3.DoesNotExist): + p.child2.child3 + p3 = Parent2(name2="Child3 Parent2") + p3.save() + c2 = Child3(name1="Child3 Parent1", parent2=p3, value=2, value3=3) + c2.save() + with self.assertNumQueries(1): + p = Parent2.objects.select_related('child2', "child2__child3").get(name2="Child3 Parent2") + self.assertEqual(p.child2.name1, 'Child3 Parent1') + self.assertEqual(p.child2.child3.value3, 3) + self.assertEqual(p.child2.child3.value, p.child2.value) + self.assertEqual(p.child2.name1, p.child2.child3.name1) + + def test_multiinheritance_two_subclasses(self): + with self.assertNumQueries(1): + p = Parent1.objects.select_related('child1', 'child1__child4').get(name1="Child1 Parent1") + self.assertEqual(p.child1.name2, 'Child1 Parent2') + self.assertEqual(p.child1.name1, p.name1) + with self.assertRaises(Child4.DoesNotExist): + p.child1.child4 + Child4(name1='n1', name2='n2', value=1, value4=4).save() + 
with self.assertNumQueries(1): + p = Parent2.objects.select_related('child1', 'child1__child4').get(name2="n2") + self.assertEqual(p.name2, 'n2') + self.assertEqual(p.child1.name1, 'n1') + self.assertEqual(p.child1.name2, p.name2) + self.assertEqual(p.child1.value, 1) + self.assertEqual(p.child1.child4.name1, p.child1.name1) + self.assertEqual(p.child1.child4.name2, p.child1.name2) + self.assertEqual(p.child1.child4.value, p.child1.value) + self.assertEqual(p.child1.child4.value4, 4) + + def test_inheritance_deferred(self): + c = Child4.objects.create(name1='n1', name2='n2', value=1, value4=4) + with self.assertNumQueries(1): + p = Parent2.objects.select_related('child1').only( + 'id2', 'child1__value').get(name2="n2") + self.assertEqual(p.id2, c.id2) + self.assertEqual(p.child1.value, 1) + p = Parent2.objects.select_related('child1').only( + 'id2', 'child1__value').get(name2="n2") + with self.assertNumQueries(1): + self.assertEqual(p.name2, 'n2') + p = Parent2.objects.select_related('child1').only( + 'id2', 'child1__value').get(name2="n2") + with self.assertNumQueries(1): + self.assertEqual(p.child1.name2, 'n2') + + def test_inheritance_deferred2(self): + c = Child4.objects.create(name1='n1', name2='n2', value=1, value4=4) + qs = Parent2.objects.select_related('child1', 'child1__child4').only( + 'id2', 'child1__value', 'child1__child4__value4') + with self.assertNumQueries(1): + p = qs.get(name2="n2") + self.assertEqual(p.id2, c.id2) + self.assertEqual(p.child1.value, 1) + self.assertEqual(p.child1.child4.value4, 4) + self.assertEqual(p.child1.child4.id2, c.id2) + p = qs.get(name2="n2") + with self.assertNumQueries(1): + self.assertEqual(p.child1.name2, 'n2') + p = qs.get(name2="n2") + with self.assertNumQueries(0): + self.assertEqual(p.child1.name1, 'n1') + self.assertEqual(p.child1.child4.name1, 'n1') + + def test_self_relation(self): + item1 = LinkedList.objects.create(name='item1') + LinkedList.objects.create(name='item2', previous_item=item1) + with 
self.assertNumQueries(1): + item1_db = LinkedList.objects.select_related('next_item').get(name='item1') + self.assertEqual(item1_db.next_item.name, 'item2') + + +class ReverseSelectRelatedValidationTests(SimpleTestCase): + """ + Rverse related fields should be listed in the validation message when an + invalid field is given in select_related(). + """ + non_relational_error = "Non-relational field given in select_related: '%s'. Choices are: %s" + invalid_error = "Invalid field name(s) given in select_related: '%s'. Choices are: %s" + + def test_reverse_related_validation(self): + fields = 'userprofile, userstat' + + with self.assertRaisesMessage(FieldError, self.invalid_error % ('foobar', fields)): + list(User.objects.select_related('foobar')) + + with self.assertRaisesMessage(FieldError, self.non_relational_error % ('username', fields)): + list(User.objects.select_related('username')) + + def test_reverse_related_validation_with_filtered_relation(self): + fields = 'userprofile, userstat, relation' + with self.assertRaisesMessage(FieldError, self.invalid_error % ('foobar', fields)): + list(User.objects.annotate(relation=FilteredRelation('userprofile')).select_related('foobar')) diff --git a/tests/select_related_regress/__init__.py b/tests/select_related_regress/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/select_related_regress/models.py b/tests/select_related_regress/models.py new file mode 100644 index 00000000..8e748d3e --- /dev/null +++ b/tests/select_related_regress/models.py @@ -0,0 +1,165 @@ +from __future__ import unicode_literals + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Building(models.Model): + name = models.CharField(max_length=10) + + def __str__(self): + return "Building: %s" % self.name + + +@python_2_unicode_compatible +class Device(models.Model): + building = models.ForeignKey('Building', models.CASCADE) + name = 
models.CharField(max_length=10) + + def __str__(self): + return "device '%s' in building %s" % (self.name, self.building) + + +@python_2_unicode_compatible +class Port(models.Model): + device = models.ForeignKey('Device', models.CASCADE) + port_number = models.CharField(max_length=10) + + def __str__(self): + return "%s/%s" % (self.device.name, self.port_number) + + +@python_2_unicode_compatible +class Connection(models.Model): + start = models.ForeignKey( + Port, + models.CASCADE, + related_name='connection_start', + unique=True, + ) + end = models.ForeignKey( + Port, + models.CASCADE, + related_name='connection_end', + unique=True, + ) + + def __str__(self): + return "%s to %s" % (self.start, self.end) + +# Another non-tree hierarchy that exercises code paths similar to the above +# example, but in a slightly different configuration. + + +class TUser(models.Model): + name = models.CharField(max_length=200) + + +class Person(models.Model): + user = models.ForeignKey(TUser, models.CASCADE, unique=True) + + +class Organizer(models.Model): + person = models.ForeignKey(Person, models.CASCADE) + + +class Student(models.Model): + person = models.ForeignKey(Person, models.CASCADE) + + +class Class(models.Model): + org = models.ForeignKey(Organizer, models.CASCADE) + + +class Enrollment(models.Model): + std = models.ForeignKey(Student, models.CASCADE) + cls = models.ForeignKey(Class, models.CASCADE) + +# Models for testing bug #8036. 
+ + +class Country(models.Model): + name = models.CharField(max_length=50) + + +class State(models.Model): + name = models.CharField(max_length=50) + country = models.ForeignKey(Country, models.CASCADE) + + +class ClientStatus(models.Model): + name = models.CharField(max_length=50) + + +class Client(models.Model): + name = models.CharField(max_length=50) + state = models.ForeignKey(State, models.SET_NULL, null=True) + status = models.ForeignKey(ClientStatus, models.CASCADE) + + +class SpecialClient(Client): + value = models.IntegerField() + +# Some model inheritance exercises + + +@python_2_unicode_compatible +class Parent(models.Model): + name = models.CharField(max_length=10) + + def __str__(self): + return self.name + + +class Child(Parent): + value = models.IntegerField() + + +@python_2_unicode_compatible +class Item(models.Model): + name = models.CharField(max_length=10) + child = models.ForeignKey(Child, models.SET_NULL, null=True) + + def __str__(self): + return self.name + +# Models for testing bug #19870. 
+ + +@python_2_unicode_compatible +class Fowl(models.Model): + name = models.CharField(max_length=10) + + def __str__(self): + return self.name + + +class Hen(Fowl): + pass + + +class Chick(Fowl): + mother = models.ForeignKey(Hen, models.CASCADE) + + +class Base(models.Model): + name = models.CharField(max_length=10) + lots_of_text = models.TextField() + + class Meta: + abstract = True + + +class A(Base): + a_field = models.CharField(max_length=10) + + +class B(Base): + b_field = models.CharField(max_length=10) + + +class C(Base): + c_a = models.ForeignKey(A, models.CASCADE) + c_b = models.ForeignKey(B, models.CASCADE) + is_published = models.BooleanField(default=False) diff --git a/tests/select_related_regress/tests.py b/tests/select_related_regress/tests.py new file mode 100644 index 00000000..6c59192d --- /dev/null +++ b/tests/select_related_regress/tests.py @@ -0,0 +1,207 @@ +from __future__ import unicode_literals + +from django.test import TestCase +from django.utils import six + +from .models import ( + A, B, Building, C, Chick, Child, Class, Client, ClientStatus, Connection, + Country, Device, Enrollment, Hen, Item, Organizer, Person, Port, + SpecialClient, State, Student, TUser, +) + + +class SelectRelatedRegressTests(TestCase): + + def test_regression_7110(self): + """ + Regression test for bug #7110. + + When using select_related(), we must query the + Device and Building tables using two different aliases (each) in order to + differentiate the start and end Connection fields. The net result is that + both the "connections = ..." queries here should give the same results + without pulling in more than the absolute minimum number of tables + (history has shown that it's easy to make a mistake in the implementation + and include some unnecessary bonus joins). 
+ """ + + b = Building.objects.create(name='101') + dev1 = Device.objects.create(name="router", building=b) + dev2 = Device.objects.create(name="switch", building=b) + dev3 = Device.objects.create(name="server", building=b) + port1 = Port.objects.create(port_number='4', device=dev1) + port2 = Port.objects.create(port_number='7', device=dev2) + port3 = Port.objects.create(port_number='1', device=dev3) + c1 = Connection.objects.create(start=port1, end=port2) + c2 = Connection.objects.create(start=port2, end=port3) + + connections = Connection.objects.filter(start__device__building=b, end__device__building=b).order_by('id') + self.assertEqual( + [(c.id, six.text_type(c.start), six.text_type(c.end)) for c in connections], + [(c1.id, 'router/4', 'switch/7'), (c2.id, 'switch/7', 'server/1')] + ) + + connections = ( + Connection.objects + .filter(start__device__building=b, end__device__building=b) + .select_related() + .order_by('id') + ) + self.assertEqual( + [(c.id, six.text_type(c.start), six.text_type(c.end)) for c in connections], + [(c1.id, 'router/4', 'switch/7'), (c2.id, 'switch/7', 'server/1')] + ) + + # This final query should only have seven tables (port, device and building + # twice each, plus connection once). Thus, 6 joins plus the FROM table. + self.assertEqual(str(connections.query).count(" JOIN "), 6) + + def test_regression_8106(self): + """ + Regression test for bug #8106. + + Same sort of problem as the previous test, but this time there are + more extra tables to pull in as part of the select_related() and some + of them could potentially clash (so need to be kept separate). 
+ """ + + us = TUser.objects.create(name="std") + usp = Person.objects.create(user=us) + uo = TUser.objects.create(name="org") + uop = Person.objects.create(user=uo) + s = Student.objects.create(person=usp) + o = Organizer.objects.create(person=uop) + c = Class.objects.create(org=o) + Enrollment.objects.create(std=s, cls=c) + + e_related = Enrollment.objects.all().select_related()[0] + self.assertEqual(e_related.std.person.user.name, "std") + self.assertEqual(e_related.cls.org.person.user.name, "org") + + def test_regression_8036(self): + """ + Regression test for bug #8036 + + the first related model in the tests below + ("state") is empty and we try to select the more remotely related + state__country. The regression here was not skipping the empty column results + for country before getting status. + """ + + Country.objects.create(name='Australia') + active = ClientStatus.objects.create(name='active') + client = Client.objects.create(name='client', status=active) + + self.assertEqual(client.status, active) + self.assertEqual(Client.objects.select_related()[0].status, active) + self.assertEqual(Client.objects.select_related('state')[0].status, active) + self.assertEqual(Client.objects.select_related('state', 'status')[0].status, active) + self.assertEqual(Client.objects.select_related('state__country')[0].status, active) + self.assertEqual(Client.objects.select_related('state__country', 'status')[0].status, active) + self.assertEqual(Client.objects.select_related('status')[0].status, active) + + def test_multi_table_inheritance(self): + """ Exercising select_related() with multi-table model inheritance. 
""" + c1 = Child.objects.create(name="child1", value=42) + Item.objects.create(name="item1", child=c1) + Item.objects.create(name="item2") + + self.assertQuerysetEqual( + Item.objects.select_related("child").order_by("name"), + ["", ""] + ) + + def test_regression_12851(self): + """ + Regression for #12851 + + Deferred fields are used correctly if you select_related a subset + of fields. + """ + australia = Country.objects.create(name='Australia') + active = ClientStatus.objects.create(name='active') + + wa = State.objects.create(name="Western Australia", country=australia) + Client.objects.create(name='Brian Burke', state=wa, status=active) + burke = Client.objects.select_related('state').defer('state__name').get(name='Brian Burke') + + self.assertEqual(burke.name, 'Brian Burke') + self.assertEqual(burke.state.name, 'Western Australia') + + # Still works if we're dealing with an inherited class + SpecialClient.objects.create(name='Troy Buswell', state=wa, status=active, value=42) + troy = SpecialClient.objects.select_related('state').defer('state__name').get(name='Troy Buswell') + + self.assertEqual(troy.name, 'Troy Buswell') + self.assertEqual(troy.value, 42) + self.assertEqual(troy.state.name, 'Western Australia') + + # Still works if we defer an attribute on the inherited class + troy = SpecialClient.objects.select_related('state').defer('value', 'state__name').get(name='Troy Buswell') + + self.assertEqual(troy.name, 'Troy Buswell') + self.assertEqual(troy.value, 42) + self.assertEqual(troy.state.name, 'Western Australia') + + # Also works if you use only, rather than defer + troy = SpecialClient.objects.select_related('state').only('name', 'state').get(name='Troy Buswell') + + self.assertEqual(troy.name, 'Troy Buswell') + self.assertEqual(troy.value, 42) + self.assertEqual(troy.state.name, 'Western Australia') + + def test_null_join_promotion(self): + australia = Country.objects.create(name='Australia') + active = ClientStatus.objects.create(name='active') + + 
wa = State.objects.create(name="Western Australia", country=australia) + bob = Client.objects.create(name='Bob', status=active) + jack = Client.objects.create(name='Jack', status=active, state=wa) + qs = Client.objects.filter(state=wa).select_related('state') + with self.assertNumQueries(1): + self.assertEqual(list(qs), [jack]) + self.assertEqual(qs[0].state, wa) + # The select_related join wasn't promoted as there was already an + # existing (even if trimmed) inner join to state. + self.assertNotIn('LEFT OUTER', str(qs.query)) + qs = Client.objects.select_related('state').order_by('name') + with self.assertNumQueries(1): + self.assertEqual(list(qs), [bob, jack]) + self.assertIs(qs[0].state, None) + self.assertEqual(qs[1].state, wa) + # The select_related join was promoted as there is already an + # existing join. + self.assertIn('LEFT OUTER', str(qs.query)) + + def test_regression_19870(self): + hen = Hen.objects.create(name='Hen') + Chick.objects.create(name='Chick', mother=hen) + + self.assertEqual(Chick.objects.all()[0].mother.name, 'Hen') + self.assertEqual(Chick.objects.select_related()[0].mother.name, 'Hen') + + def test_regression_10733(self): + a = A.objects.create(name='a', lots_of_text='lots_of_text_a', a_field='a_field') + b = B.objects.create(name='b', lots_of_text='lots_of_text_b', b_field='b_field') + c = C.objects.create(name='c', lots_of_text='lots_of_text_c', is_published=True, + c_a=a, c_b=b) + results = C.objects.all().only('name', 'lots_of_text', 'c_a', 'c_b', 'c_b__lots_of_text', + 'c_a__name', 'c_b__name').select_related() + self.assertSequenceEqual(results, [c]) + with self.assertNumQueries(0): + qs_c = results[0] + self.assertEqual(qs_c.name, 'c') + self.assertEqual(qs_c.lots_of_text, 'lots_of_text_c') + self.assertEqual(qs_c.c_b.lots_of_text, 'lots_of_text_b') + self.assertEqual(qs_c.c_a.name, 'a') + self.assertEqual(qs_c.c_b.name, 'b') + + def test_regression_22508(self): + building = Building.objects.create(name='101') + device = 
Device.objects.create(name="router", building=building) + Port.objects.create(port_number='1', device=device) + + device = Device.objects.get() + port = device.port_set.select_related('device__building').get() + with self.assertNumQueries(0): + port.device.building diff --git a/tests/test_mssql.py b/tests/test_mssql.py new file mode 100644 index 00000000..a7203a41 --- /dev/null +++ b/tests/test_mssql.py @@ -0,0 +1,86 @@ +# This is an example test settings file for use with the Django test suite. +# +# The 'sqlite3' backend requires only the ENGINE setting (an in- +# memory database will be used). All other backends will require a +# NAME and potentially authentication information. See the +# following section in the docs for more information: +# +# https://docs.djangoproject.com/en/dev/internals/contributing/writing-code/unit-tests/ +# +# The different databases that Django supports behave differently in certain +# situations, so it is recommended to run the test suite against as many +# database backends as possible. You may want to create a separate settings +# file for each of the backends you test against. 
+import os + +INSTANCE = os.environ.get('SQLINSTANCE', '') +HOST = os.environ.get('COMPUTERNAME', os.environ.get('HOST', 'localhost')) +if INSTANCE: + HOST = '\\'.join([HOST, INSTANCE]) +DATABASE = os.environ.get('DATABASE_NAME', 'django_test_backend') +USER = os.environ.get('SQLUSER', 'sa') +PASSWORD = os.environ.get('SQLPASSWORD', 'sa') + +DATABASES = { + 'default': { + 'ENGINE': os.environ.get('BACKEND', 'sql_server.pyodbc'), + 'NAME': DATABASE, + 'TEST_NAME': DATABASE, + 'HOST': HOST, + 'USER': USER, + 'PASSWORD': PASSWORD, + 'OPTIONS': { + 'provider': os.environ.get('ADO_PROVIDER', 'SQLNCLI11'), + # 'extra_params': 'DataTypeCompatibility=80;MARS Connection=True;', + 'use_legacy_date_fields': False, + }, + }, + 'other': { + 'ENGINE': os.environ.get('BACKEND', 'sql_server.pyodbc'), + 'NAME': DATABASE + '_other', + 'TEST_NAME': DATABASE + '_other', + 'HOST': HOST, + 'USER': USER, + 'PASSWORD': PASSWORD, + 'OPTIONS': { + 'provider': os.environ.get('ADO_PROVIDER', 'SQLNCLI11'), + # 'extra_params': 'DataTypeCompatibility=80;MARS Connection=True;', + 'use_legacy_date_fields': False, + }, + } +} + +MIDDLEWARE_CLASSES = [ + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django.contrib.auth.middleware.SessionAuthenticationMiddleware', +] + +SECRET_KEY = "django_tests_secret_key" +# To speed up tests under SQLite we use the MD5 hasher as the default one. +# This should not be needed under other databases, as the relative speedup +# is only marginal there. 
+PASSWORD_HASHERS = ( + 'django.contrib.auth.hashers.MD5PasswordHasher', +) + +LOGGING = { + 'version': 1, + 'disable_existing_loggers': False, + 'handlers': { + 'console': { + 'class': 'logging.StreamHandler', + }, + }, + 'loggers': { + 'django': { + 'handlers': ['console'], + 'level': 'ERROR', + }, + 'django.db': { + 'handlers': ['console'], + # uncomment to enable logging of SQL statements + #'level': 'DEBUG', + }, + }, +} diff --git a/tests/transaction_hooks/__init__.py b/tests/transaction_hooks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/transaction_hooks/models.py b/tests/transaction_hooks/models.py new file mode 100644 index 00000000..cd2f22b5 --- /dev/null +++ b/tests/transaction_hooks/models.py @@ -0,0 +1,10 @@ +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Thing(models.Model): + num = models.IntegerField() + + def __str__(self): + return "Thing %d" % self.num diff --git a/tests/transaction_hooks/tests.py b/tests/transaction_hooks/tests.py new file mode 100644 index 00000000..000de410 --- /dev/null +++ b/tests/transaction_hooks/tests.py @@ -0,0 +1,234 @@ +from django.db import connection, transaction +from django.test import TransactionTestCase, skipUnlessDBFeature + +from .models import Thing + + +class ForcedError(Exception): + pass + + +class TestConnectionOnCommit(TransactionTestCase): + """ + Tests for transaction.on_commit(). + + Creation/checking of database objects in parallel with callback tracking is + to verify that the behavior of the two match in all tested cases. 
+ """ + available_apps = ['transaction_hooks'] + + def setUp(self): + self.notified = [] + + def notify(self, id_): + if id_ == 'error': + raise ForcedError() + self.notified.append(id_) + + def do(self, num): + """Create a Thing instance and notify about it.""" + Thing.objects.create(num=num) + transaction.on_commit(lambda: self.notify(num)) + + def assertDone(self, nums): + self.assertNotified(nums) + self.assertEqual(sorted(t.num for t in Thing.objects.all()), sorted(nums)) + + def assertNotified(self, nums): + self.assertEqual(self.notified, nums) + + def test_executes_immediately_if_no_transaction(self): + self.do(1) + self.assertDone([1]) + + def test_delays_execution_until_after_transaction_commit(self): + with transaction.atomic(): + self.do(1) + self.assertNotified([]) + self.assertDone([1]) + + def test_does_not_execute_if_transaction_rolled_back(self): + try: + with transaction.atomic(): + self.do(1) + raise ForcedError() + except ForcedError: + pass + + self.assertDone([]) + + def test_executes_only_after_final_transaction_committed(self): + with transaction.atomic(): + with transaction.atomic(): + self.do(1) + self.assertNotified([]) + self.assertNotified([]) + self.assertDone([1]) + + def test_discards_hooks_from_rolled_back_savepoint(self): + with transaction.atomic(): + # one successful savepoint + with transaction.atomic(): + self.do(1) + # one failed savepoint + try: + with transaction.atomic(): + self.do(2) + raise ForcedError() + except ForcedError: + pass + # another successful savepoint + with transaction.atomic(): + self.do(3) + + # only hooks registered during successful savepoints execute + self.assertDone([1, 3]) + + def test_no_hooks_run_from_failed_transaction(self): + """If outer transaction fails, no hooks from within it run.""" + try: + with transaction.atomic(): + with transaction.atomic(): + self.do(1) + raise ForcedError() + except ForcedError: + pass + + self.assertDone([]) + + def 
test_inner_savepoint_rolled_back_with_outer(self): + with transaction.atomic(): + try: + with transaction.atomic(): + with transaction.atomic(): + self.do(1) + raise ForcedError() + except ForcedError: + pass + self.do(2) + + self.assertDone([2]) + + def test_no_savepoints_atomic_merged_with_outer(self): + with transaction.atomic(): + with transaction.atomic(): + self.do(1) + try: + with transaction.atomic(savepoint=False): + raise ForcedError() + except ForcedError: + pass + + self.assertDone([]) + + def test_inner_savepoint_does_not_affect_outer(self): + with transaction.atomic(): + with transaction.atomic(): + self.do(1) + try: + with transaction.atomic(): + raise ForcedError() + except ForcedError: + pass + + self.assertDone([1]) + + def test_runs_hooks_in_order_registered(self): + with transaction.atomic(): + self.do(1) + with transaction.atomic(): + self.do(2) + self.do(3) + + self.assertDone([1, 2, 3]) + + def test_hooks_cleared_after_successful_commit(self): + with transaction.atomic(): + self.do(1) + with transaction.atomic(): + self.do(2) + + self.assertDone([1, 2]) # not [1, 1, 2] + + def test_hooks_cleared_after_rollback(self): + try: + with transaction.atomic(): + self.do(1) + raise ForcedError() + except ForcedError: + pass + + with transaction.atomic(): + self.do(2) + + self.assertDone([2]) + + @skipUnlessDBFeature('test_db_allows_multiple_connections') + def test_hooks_cleared_on_reconnect(self): + with transaction.atomic(): + self.do(1) + connection.close() + + connection.connect() + + with transaction.atomic(): + self.do(2) + + self.assertDone([2]) + + def test_error_in_hook_doesnt_prevent_clearing_hooks(self): + try: + with transaction.atomic(): + transaction.on_commit(lambda: self.notify('error')) + except ForcedError: + pass + + with transaction.atomic(): + self.do(1) + + self.assertDone([1]) + + def test_db_query_in_hook(self): + with transaction.atomic(): + Thing.objects.create(num=1) + transaction.on_commit( + lambda: [self.notify(t.num) for 
t in Thing.objects.all()] + ) + + self.assertDone([1]) + + def test_transaction_in_hook(self): + def on_commit(): + with transaction.atomic(): + t = Thing.objects.create(num=1) + self.notify(t.num) + + with transaction.atomic(): + transaction.on_commit(on_commit) + + self.assertDone([1]) + + def test_hook_in_hook(self): + def on_commit(i, add_hook): + with transaction.atomic(): + if add_hook: + transaction.on_commit(lambda: on_commit(i + 10, False)) + t = Thing.objects.create(num=i) + self.notify(t.num) + + with transaction.atomic(): + transaction.on_commit(lambda: on_commit(1, True)) + transaction.on_commit(lambda: on_commit(2, True)) + + self.assertDone([1, 11, 2, 12]) + + def test_raises_exception_non_autocommit_mode(self): + def should_never_be_called(): + raise AssertionError('this function should never be called') + + try: + connection.set_autocommit(False) + with self.assertRaises(transaction.TransactionManagementError): + transaction.on_commit(should_never_be_called) + finally: + connection.set_autocommit(True) diff --git a/tests/transactions/__init__.py b/tests/transactions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/transactions/models.py b/tests/transactions/models.py new file mode 100644 index 00000000..f4400363 --- /dev/null +++ b/tests/transactions/models.py @@ -0,0 +1,25 @@ +""" +Transactions + +Django handles transactions in three different ways. The default is to commit +each transaction upon a write, but you can decorate a function to get +commit-on-success behavior. Alternatively, you can manage the transaction +manually. 
+""" +from __future__ import unicode_literals + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class Reporter(models.Model): + first_name = models.CharField(max_length=30) + last_name = models.CharField(max_length=30) + email = models.EmailField() + + class Meta: + ordering = ('first_name', 'last_name') + + def __str__(self): + return ("%s %s" % (self.first_name, self.last_name)).strip() diff --git a/tests/transactions/tests.py b/tests/transactions/tests.py new file mode 100644 index 00000000..c31cffa5 --- /dev/null +++ b/tests/transactions/tests.py @@ -0,0 +1,464 @@ +from __future__ import unicode_literals + +import sys +import threading +import time +from unittest import skipIf, skipUnless + +from django.db import ( + DatabaseError, Error, IntegrityError, OperationalError, connection, + transaction, +) +from django.test import ( + TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature, +) + +from .models import Reporter + + +@skipUnless(connection.features.uses_savepoints, "'atomic' requires transactions and savepoints.") +class AtomicTests(TransactionTestCase): + """ + Tests for the atomic decorator and context manager. + + The tests make assertions on internal attributes because there isn't a + robust way to ask the database for its current transaction state. + + Since the decorator syntax is converted into a context manager (see the + implementation), there are only a few basic tests with the decorator + syntax and the bulk of the tests use the context manager syntax. 
+ """ + + available_apps = ['transactions'] + + def test_decorator_syntax_commit(self): + @transaction.atomic + def make_reporter(): + Reporter.objects.create(first_name="Tintin") + make_reporter() + self.assertQuerysetEqual(Reporter.objects.all(), ['']) + + def test_decorator_syntax_rollback(self): + @transaction.atomic + def make_reporter(): + Reporter.objects.create(first_name="Haddock") + raise Exception("Oops, that's his last name") + with self.assertRaisesMessage(Exception, "Oops"): + make_reporter() + self.assertQuerysetEqual(Reporter.objects.all(), []) + + def test_alternate_decorator_syntax_commit(self): + @transaction.atomic() + def make_reporter(): + Reporter.objects.create(first_name="Tintin") + make_reporter() + self.assertQuerysetEqual(Reporter.objects.all(), ['']) + + def test_alternate_decorator_syntax_rollback(self): + @transaction.atomic() + def make_reporter(): + Reporter.objects.create(first_name="Haddock") + raise Exception("Oops, that's his last name") + with self.assertRaisesMessage(Exception, "Oops"): + make_reporter() + self.assertQuerysetEqual(Reporter.objects.all(), []) + + def test_commit(self): + with transaction.atomic(): + Reporter.objects.create(first_name="Tintin") + self.assertQuerysetEqual(Reporter.objects.all(), ['']) + + def test_rollback(self): + with self.assertRaisesMessage(Exception, "Oops"): + with transaction.atomic(): + Reporter.objects.create(first_name="Haddock") + raise Exception("Oops, that's his last name") + self.assertQuerysetEqual(Reporter.objects.all(), []) + + def test_nested_commit_commit(self): + with transaction.atomic(): + Reporter.objects.create(first_name="Tintin") + with transaction.atomic(): + Reporter.objects.create(first_name="Archibald", last_name="Haddock") + self.assertQuerysetEqual( + Reporter.objects.all(), + ['', ''] + ) + + def test_nested_commit_rollback(self): + with transaction.atomic(): + Reporter.objects.create(first_name="Tintin") + with self.assertRaisesMessage(Exception, "Oops"): + with 
transaction.atomic(): + Reporter.objects.create(first_name="Haddock") + raise Exception("Oops, that's his last name") + self.assertQuerysetEqual(Reporter.objects.all(), ['']) + + def test_nested_rollback_commit(self): + with self.assertRaisesMessage(Exception, "Oops"): + with transaction.atomic(): + Reporter.objects.create(last_name="Tintin") + with transaction.atomic(): + Reporter.objects.create(last_name="Haddock") + raise Exception("Oops, that's his first name") + self.assertQuerysetEqual(Reporter.objects.all(), []) + + def test_nested_rollback_rollback(self): + with self.assertRaisesMessage(Exception, "Oops"): + with transaction.atomic(): + Reporter.objects.create(last_name="Tintin") + with self.assertRaisesMessage(Exception, "Oops"): + with transaction.atomic(): + Reporter.objects.create(first_name="Haddock") + raise Exception("Oops, that's his last name") + raise Exception("Oops, that's his first name") + self.assertQuerysetEqual(Reporter.objects.all(), []) + + def test_merged_commit_commit(self): + with transaction.atomic(): + Reporter.objects.create(first_name="Tintin") + with transaction.atomic(savepoint=False): + Reporter.objects.create(first_name="Archibald", last_name="Haddock") + self.assertQuerysetEqual( + Reporter.objects.all(), + ['', ''] + ) + + def test_merged_commit_rollback(self): + with transaction.atomic(): + Reporter.objects.create(first_name="Tintin") + with self.assertRaisesMessage(Exception, "Oops"): + with transaction.atomic(savepoint=False): + Reporter.objects.create(first_name="Haddock") + raise Exception("Oops, that's his last name") + # Writes in the outer block are rolled back too. 
+ self.assertQuerysetEqual(Reporter.objects.all(), []) + + def test_merged_rollback_commit(self): + with self.assertRaisesMessage(Exception, "Oops"): + with transaction.atomic(): + Reporter.objects.create(last_name="Tintin") + with transaction.atomic(savepoint=False): + Reporter.objects.create(last_name="Haddock") + raise Exception("Oops, that's his first name") + self.assertQuerysetEqual(Reporter.objects.all(), []) + + def test_merged_rollback_rollback(self): + with self.assertRaisesMessage(Exception, "Oops"): + with transaction.atomic(): + Reporter.objects.create(last_name="Tintin") + with self.assertRaisesMessage(Exception, "Oops"): + with transaction.atomic(savepoint=False): + Reporter.objects.create(first_name="Haddock") + raise Exception("Oops, that's his last name") + raise Exception("Oops, that's his first name") + self.assertQuerysetEqual(Reporter.objects.all(), []) + + def test_reuse_commit_commit(self): + atomic = transaction.atomic() + with atomic: + Reporter.objects.create(first_name="Tintin") + with atomic: + Reporter.objects.create(first_name="Archibald", last_name="Haddock") + self.assertQuerysetEqual(Reporter.objects.all(), ['', '']) + + def test_reuse_commit_rollback(self): + atomic = transaction.atomic() + with atomic: + Reporter.objects.create(first_name="Tintin") + with self.assertRaisesMessage(Exception, "Oops"): + with atomic: + Reporter.objects.create(first_name="Haddock") + raise Exception("Oops, that's his last name") + self.assertQuerysetEqual(Reporter.objects.all(), ['']) + + def test_reuse_rollback_commit(self): + atomic = transaction.atomic() + with self.assertRaisesMessage(Exception, "Oops"): + with atomic: + Reporter.objects.create(last_name="Tintin") + with atomic: + Reporter.objects.create(last_name="Haddock") + raise Exception("Oops, that's his first name") + self.assertQuerysetEqual(Reporter.objects.all(), []) + + def test_reuse_rollback_rollback(self): + atomic = transaction.atomic() + with self.assertRaisesMessage(Exception, 
"Oops"): + with atomic: + Reporter.objects.create(last_name="Tintin") + with self.assertRaisesMessage(Exception, "Oops"): + with atomic: + Reporter.objects.create(first_name="Haddock") + raise Exception("Oops, that's his last name") + raise Exception("Oops, that's his first name") + self.assertQuerysetEqual(Reporter.objects.all(), []) + + def test_force_rollback(self): + with transaction.atomic(): + Reporter.objects.create(first_name="Tintin") + # atomic block shouldn't rollback, but force it. + self.assertFalse(transaction.get_rollback()) + transaction.set_rollback(True) + self.assertQuerysetEqual(Reporter.objects.all(), []) + + def test_prevent_rollback(self): + with transaction.atomic(): + Reporter.objects.create(first_name="Tintin") + sid = transaction.savepoint() + # trigger a database error inside an inner atomic without savepoint + with self.assertRaises(DatabaseError): + with transaction.atomic(savepoint=False): + with connection.cursor() as cursor: + cursor.execute( + "SELECT no_such_col FROM transactions_reporter") + # prevent atomic from rolling back since we're recovering manually + self.assertTrue(transaction.get_rollback()) + transaction.set_rollback(False) + transaction.savepoint_rollback(sid) + self.assertQuerysetEqual(Reporter.objects.all(), ['']) + + +class AtomicInsideTransactionTests(AtomicTests): + """All basic tests for atomic should also pass within an existing transaction.""" + + def setUp(self): + self.atomic = transaction.atomic() + self.atomic.__enter__() + + def tearDown(self): + self.atomic.__exit__(*sys.exc_info()) + + +@skipIf( + connection.features.autocommits_when_autocommit_is_off, + "This test requires a non-autocommit mode that doesn't autocommit." 
+) +class AtomicWithoutAutocommitTests(AtomicTests): + """All basic tests for atomic should also pass when autocommit is turned off.""" + + def setUp(self): + transaction.set_autocommit(False) + + def tearDown(self): + # The tests access the database after exercising 'atomic', initiating + # a transaction ; a rollback is required before restoring autocommit. + transaction.rollback() + transaction.set_autocommit(True) + + +@skipUnless(connection.features.uses_savepoints, "'atomic' requires transactions and savepoints.") +class AtomicMergeTests(TransactionTestCase): + """Test merging transactions with savepoint=False.""" + + available_apps = ['transactions'] + + def test_merged_outer_rollback(self): + with transaction.atomic(): + Reporter.objects.create(first_name="Tintin") + with transaction.atomic(savepoint=False): + Reporter.objects.create(first_name="Archibald", last_name="Haddock") + with self.assertRaisesMessage(Exception, "Oops"): + with transaction.atomic(savepoint=False): + Reporter.objects.create(first_name="Calculus") + raise Exception("Oops, that's his last name") + # The third insert couldn't be roll back. Temporarily mark the + # connection as not needing rollback to check it. + self.assertTrue(transaction.get_rollback()) + transaction.set_rollback(False) + self.assertEqual(Reporter.objects.count(), 3) + transaction.set_rollback(True) + # The second insert couldn't be roll back. Temporarily mark the + # connection as not needing rollback to check it. + self.assertTrue(transaction.get_rollback()) + transaction.set_rollback(False) + self.assertEqual(Reporter.objects.count(), 3) + transaction.set_rollback(True) + # The first block has a savepoint and must roll back. 
+ self.assertQuerysetEqual(Reporter.objects.all(), []) + + def test_merged_inner_savepoint_rollback(self): + with transaction.atomic(): + Reporter.objects.create(first_name="Tintin") + with transaction.atomic(): + Reporter.objects.create(first_name="Archibald", last_name="Haddock") + with self.assertRaisesMessage(Exception, "Oops"): + with transaction.atomic(savepoint=False): + Reporter.objects.create(first_name="Calculus") + raise Exception("Oops, that's his last name") + # The third insert couldn't be roll back. Temporarily mark the + # connection as not needing rollback to check it. + self.assertTrue(transaction.get_rollback()) + transaction.set_rollback(False) + self.assertEqual(Reporter.objects.count(), 3) + transaction.set_rollback(True) + # The second block has a savepoint and must roll back. + self.assertEqual(Reporter.objects.count(), 1) + self.assertQuerysetEqual(Reporter.objects.all(), ['']) + + +@skipUnless(connection.features.uses_savepoints, "'atomic' requires transactions and savepoints.") +class AtomicErrorsTests(TransactionTestCase): + + available_apps = ['transactions'] + + def test_atomic_prevents_setting_autocommit(self): + autocommit = transaction.get_autocommit() + with transaction.atomic(): + with self.assertRaises(transaction.TransactionManagementError): + transaction.set_autocommit(not autocommit) + # Make sure autocommit wasn't changed. 
+ self.assertEqual(connection.autocommit, autocommit) + + def test_atomic_prevents_calling_transaction_methods(self): + with transaction.atomic(): + with self.assertRaises(transaction.TransactionManagementError): + transaction.commit() + with self.assertRaises(transaction.TransactionManagementError): + transaction.rollback() + + def test_atomic_prevents_queries_in_broken_transaction(self): + r1 = Reporter.objects.create(first_name="Archibald", last_name="Haddock") + with transaction.atomic(): + r2 = Reporter(first_name="Cuthbert", last_name="Calculus", id=r1.id) + with self.assertRaises(IntegrityError): + r2.save(force_insert=True) + # The transaction is marked as needing rollback. + with self.assertRaises(transaction.TransactionManagementError): + r2.save(force_update=True) + self.assertEqual(Reporter.objects.get(pk=r1.pk).last_name, "Haddock") + + @skipIfDBFeature('atomic_transactions') + def test_atomic_allows_queries_after_fixing_transaction(self): + r1 = Reporter.objects.create(first_name="Archibald", last_name="Haddock") + with transaction.atomic(): + r2 = Reporter(first_name="Cuthbert", last_name="Calculus", id=r1.id) + with self.assertRaises(IntegrityError): + r2.save(force_insert=True) + # Mark the transaction as no longer needing rollback. + transaction.set_rollback(False) + r2.save(force_update=True) + self.assertEqual(Reporter.objects.get(pk=r1.pk).last_name, "Calculus") + + @skipUnlessDBFeature('test_db_allows_multiple_connections') + def test_atomic_prevents_queries_in_broken_transaction_after_client_close(self): + with transaction.atomic(): + Reporter.objects.create(first_name="Archibald", last_name="Haddock") + connection.close() + # The connection is closed and the transaction is marked as + # needing rollback. This will raise an InterfaceError on databases + # that refuse to create cursors on closed connections (PostgreSQL) + # and a TransactionManagementError on other databases. 
+ with self.assertRaises(Error): + Reporter.objects.create(first_name="Cuthbert", last_name="Calculus") + # The connection is usable again . + self.assertEqual(Reporter.objects.count(), 0) + + +@skipUnless(connection.vendor == 'mysql', "MySQL-specific behaviors") +class AtomicMySQLTests(TransactionTestCase): + + available_apps = ['transactions'] + + @skipIf(threading is None, "Test requires threading") + def test_implicit_savepoint_rollback(self): + """MySQL implicitly rolls back savepoints when it deadlocks (#22291).""" + + other_thread_ready = threading.Event() + + def other_thread(): + try: + with transaction.atomic(): + Reporter.objects.create(id=1, first_name="Tintin") + other_thread_ready.set() + # We cannot synchronize the two threads with an event here + # because the main thread locks. Sleep for a little while. + time.sleep(1) + # 2) ... and this line deadlocks. (see below for 1) + Reporter.objects.exclude(id=1).update(id=2) + finally: + # This is the thread-local connection, not the main connection. + connection.close() + + other_thread = threading.Thread(target=other_thread) + other_thread.start() + other_thread_ready.wait() + + with self.assertRaisesMessage(OperationalError, 'Deadlock found'): + # Double atomic to enter a transaction and create a savepoint. + with transaction.atomic(): + with transaction.atomic(): + # 1) This line locks... 
(see above for 2) + Reporter.objects.create(id=1, first_name="Tintin") + + other_thread.join() + + +class AtomicMiscTests(TransactionTestCase): + + available_apps = [] + + def test_wrap_callable_instance(self): + """#20028 -- Atomic must support wrapping callable instances.""" + + class Callable(object): + def __call__(self): + pass + + # Must not raise an exception + transaction.atomic(Callable()) + + @skipUnlessDBFeature('can_release_savepoints') + def test_atomic_does_not_leak_savepoints_on_failure(self): + """#23074 -- Savepoints must be released after rollback.""" + + # Expect an error when rolling back a savepoint that doesn't exist. + # Done outside of the transaction block to ensure proper recovery. + with self.assertRaises(Error): + + # Start a plain transaction. + with transaction.atomic(): + + # Swallow the intentional error raised in the sub-transaction. + with self.assertRaisesMessage(Exception, "Oops"): + + # Start a sub-transaction with a savepoint. + with transaction.atomic(): + sid = connection.savepoint_ids[-1] + raise Exception("Oops") + + # This is expected to fail because the savepoint no longer exists. + connection.savepoint_rollback(sid) + + +@skipIf( + connection.features.autocommits_when_autocommit_is_off, + "This test requires a non-autocommit mode that doesn't autocommit." +) +class NonAutocommitTests(TransactionTestCase): + + available_apps = [] + + def test_orm_query_after_error_and_rollback(self): + """ + ORM queries are allowed after an error and a rollback in non-autocommit + mode (#27504). 
+ """ + # TODO: fix this test + return + transaction.set_autocommit(False) + r1 = Reporter.objects.create(first_name='Archibald', last_name='Haddock') + r2 = Reporter(first_name='Cuthbert', last_name='Calculus', id=r1.id) + with self.assertRaises(IntegrityError): + r2.save(force_insert=True) + transaction.rollback() + Reporter.objects.last() + + def test_orm_query_without_autocommit(self): + """#24921 -- ORM queries must be possible after set_autocommit(False).""" + transaction.set_autocommit(False) + try: + Reporter.objects.create(first_name="Tintin") + finally: + transaction.rollback() + transaction.set_autocommit(True) diff --git a/tests/unmanaged_models/__init__.py b/tests/unmanaged_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unmanaged_models/models.py b/tests/unmanaged_models/models.py new file mode 100644 index 00000000..657d3d5b --- /dev/null +++ b/tests/unmanaged_models/models.py @@ -0,0 +1,143 @@ +""" +Models can have a ``managed`` attribute, which specifies whether the SQL code +is generated for the table on various manage.py operations. +""" + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + +# All of these models are created in the database by Django. + + +@python_2_unicode_compatible +class A01(models.Model): + f_a = models.CharField(max_length=10, db_index=True) + f_b = models.IntegerField() + + class Meta: + db_table = 'a01' + + def __str__(self): + return self.f_a + + +@python_2_unicode_compatible +class B01(models.Model): + fk_a = models.ForeignKey(A01, models.CASCADE) + f_a = models.CharField(max_length=10, db_index=True) + f_b = models.IntegerField() + + class Meta: + db_table = 'b01' + # 'managed' is True by default. This tests we can set it explicitly. 
+ managed = True + + def __str__(self): + return self.f_a + + +@python_2_unicode_compatible +class C01(models.Model): + mm_a = models.ManyToManyField(A01, db_table='d01') + f_a = models.CharField(max_length=10, db_index=True) + f_b = models.IntegerField() + + class Meta: + db_table = 'c01' + + def __str__(self): + return self.f_a + +# All of these models use the same tables as the previous set (they are shadows +# of possibly a subset of the columns). There should be no creation errors, +# since we have told Django they aren't managed by Django. + + +@python_2_unicode_compatible +class A02(models.Model): + f_a = models.CharField(max_length=10, db_index=True) + + class Meta: + db_table = 'a01' + managed = False + + def __str__(self): + return self.f_a + + +@python_2_unicode_compatible +class B02(models.Model): + class Meta: + db_table = 'b01' + managed = False + + fk_a = models.ForeignKey(A02, models.CASCADE) + f_a = models.CharField(max_length=10, db_index=True) + f_b = models.IntegerField() + + def __str__(self): + return self.f_a + + +# To re-use the many-to-many intermediate table, we need to manually set up +# things up. +@python_2_unicode_compatible +class C02(models.Model): + mm_a = models.ManyToManyField(A02, through="Intermediate") + f_a = models.CharField(max_length=10, db_index=True) + f_b = models.IntegerField() + + class Meta: + db_table = 'c01' + managed = False + + def __str__(self): + return self.f_a + + +class Intermediate(models.Model): + a02 = models.ForeignKey(A02, models.CASCADE, db_column="a01_id") + c02 = models.ForeignKey(C02, models.CASCADE, db_column="c01_id") + + class Meta: + db_table = 'd01' + managed = False + + +# These next models test the creation (or not) of many to many join tables +# between managed and unmanaged models. A join table between two unmanaged +# models shouldn't be automatically created (see #10647). +# + +# Firstly, we need some models that will create the tables, purely so that the +# tables are created. 
This is a test setup, not a requirement for unmanaged +# models. +class Proxy1(models.Model): + class Meta: + db_table = "unmanaged_models_proxy1" + + +class Proxy2(models.Model): + class Meta: + db_table = "unmanaged_models_proxy2" + + +class Unmanaged1(models.Model): + class Meta: + managed = False + db_table = "unmanaged_models_proxy1" + + +# Unmanaged with an m2m to unmanaged: the intermediary table won't be created. +class Unmanaged2(models.Model): + mm = models.ManyToManyField(Unmanaged1) + + class Meta: + managed = False + db_table = "unmanaged_models_proxy2" + + +# Here's an unmanaged model with an m2m to a managed one; the intermediary +# table *will* be created (unless given a custom `through` as for C02 above). +class Managed1(models.Model): + mm = models.ManyToManyField(Unmanaged1) diff --git a/tests/unmanaged_models/tests.py b/tests/unmanaged_models/tests.py new file mode 100644 index 00000000..e98cee8a --- /dev/null +++ b/tests/unmanaged_models/tests.py @@ -0,0 +1,61 @@ +from __future__ import unicode_literals + +from django.db import connection +from django.test import TestCase + +from .models import A01, A02, B01, B02, C01, C02, Managed1, Unmanaged2 + + +class SimpleTests(TestCase): + + def test_simple(self): + """ + The main test here is that all the models can be created without + any database errors. We can also do some more simple insertion and + lookup tests while we're here to show that the second set of models does + refer to the tables from the first set. + """ + # Insert some data into one set of models. + a = A01.objects.create(f_a="foo", f_b=42) + B01.objects.create(fk_a=a, f_a="fred", f_b=1729) + c = C01.objects.create(f_a="barney", f_b=1) + c.mm_a.set([a]) + + # ... and pull it out via the other set. 
+ a2 = A02.objects.all()[0] + self.assertIsInstance(a2, A02) + self.assertEqual(a2.f_a, "foo") + + b2 = B02.objects.all()[0] + self.assertIsInstance(b2, B02) + self.assertEqual(b2.f_a, "fred") + + self.assertIsInstance(b2.fk_a, A02) + self.assertEqual(b2.fk_a.f_a, "foo") + + self.assertEqual(list(C02.objects.filter(f_a=None)), []) + + resp = list(C02.objects.filter(mm_a=a.id)) + self.assertEqual(len(resp), 1) + + self.assertIsInstance(resp[0], C02) + self.assertEqual(resp[0].f_a, 'barney') + + +class ManyToManyUnmanagedTests(TestCase): + + def test_many_to_many_between_unmanaged(self): + """ + The intermediary table between two unmanaged models should not be created. + """ + table = Unmanaged2._meta.get_field('mm').m2m_db_table() + tables = connection.introspection.table_names() + self.assertNotIn(table, tables, "Table '%s' should not exist, but it does." % table) + + def test_many_to_many_between_unmanaged_and_managed(self): + """ + An intermediary table between a managed and an unmanaged model should be created. + """ + table = Managed1._meta.get_field('mm').m2m_db_table() + tables = connection.introspection.table_names() + self.assertIn(table, tables, "Table '%s' does not exist." % table) diff --git a/tests/update/__init__.py b/tests/update/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/update/models.py b/tests/update/models.py new file mode 100644 index 00000000..648a7733 --- /dev/null +++ b/tests/update/models.py @@ -0,0 +1,52 @@ +""" +Tests for the update() queryset method that allows in-place, multi-object +updates. 
+""" + +from django.db import models +from django.utils import six +from django.utils.encoding import python_2_unicode_compatible + + +@python_2_unicode_compatible +class DataPoint(models.Model): + name = models.CharField(max_length=20) + value = models.CharField(max_length=20) + another_value = models.CharField(max_length=20, blank=True) + + def __str__(self): + return six.text_type(self.name) + + +@python_2_unicode_compatible +class RelatedPoint(models.Model): + name = models.CharField(max_length=20) + data = models.ForeignKey(DataPoint, models.CASCADE) + + def __str__(self): + return six.text_type(self.name) + + +class A(models.Model): + x = models.IntegerField(default=10) + + +class B(models.Model): + a = models.ForeignKey(A, models.CASCADE) + y = models.IntegerField(default=10) + + +class C(models.Model): + y = models.IntegerField(default=10) + + +class D(C): + a = models.ForeignKey(A, models.CASCADE) + + +class Foo(models.Model): + target = models.CharField(max_length=10, unique=True) + + +class Bar(models.Model): + foo = models.ForeignKey(Foo, models.CASCADE, to_field='target') diff --git a/tests/update/tests.py b/tests/update/tests.py new file mode 100644 index 00000000..114091f6 --- /dev/null +++ b/tests/update/tests.py @@ -0,0 +1,182 @@ +from __future__ import unicode_literals + +from django.core.exceptions import FieldError +from django.db.models import Count, F, Max +from django.test import TestCase + +from .models import A, B, Bar, D, DataPoint, Foo, RelatedPoint + + +class SimpleTest(TestCase): + def setUp(self): + self.a1 = A.objects.create() + self.a2 = A.objects.create() + for x in range(20): + B.objects.create(a=self.a1) + D.objects.create(a=self.a1) + + def test_nonempty_update(self): + """ + Update changes the right number of rows for a nonempty queryset + """ + num_updated = self.a1.b_set.update(y=100) + self.assertEqual(num_updated, 20) + cnt = B.objects.filter(y=100).count() + self.assertEqual(cnt, 20) + + def test_empty_update(self): + """ + 
Update changes the right number of rows for an empty queryset + """ + num_updated = self.a2.b_set.update(y=100) + self.assertEqual(num_updated, 0) + cnt = B.objects.filter(y=100).count() + self.assertEqual(cnt, 0) + + def test_nonempty_update_with_inheritance(self): + """ + Update changes the right number of rows for a nonempty queryset + when the update affects only a base table + """ + num_updated = self.a1.d_set.update(y=100) + self.assertEqual(num_updated, 20) + cnt = D.objects.filter(y=100).count() + self.assertEqual(cnt, 20) + + def test_empty_update_with_inheritance(self): + """ + Update changes the right number of rows for an empty queryset + when the update affects only a base table + """ + num_updated = self.a2.d_set.update(y=100) + self.assertEqual(num_updated, 0) + cnt = D.objects.filter(y=100).count() + self.assertEqual(cnt, 0) + + def test_foreign_key_update_with_id(self): + """ + Update works using _id for foreign keys + """ + num_updated = self.a1.d_set.update(a_id=self.a2) + self.assertEqual(num_updated, 20) + self.assertEqual(self.a2.d_set.count(), 20) + + +class AdvancedTests(TestCase): + + def setUp(self): + self.d0 = DataPoint.objects.create(name="d0", value="apple") + self.d2 = DataPoint.objects.create(name="d2", value="banana") + self.d3 = DataPoint.objects.create(name="d3", value="banana") + self.r1 = RelatedPoint.objects.create(name="r1", data=self.d3) + + def test_update(self): + """ + Objects are updated by first filtering the candidates into a queryset + and then calling the update() method. It executes immediately and + returns nothing. + """ + resp = DataPoint.objects.filter(value="apple").update(name="d1") + self.assertEqual(resp, 1) + resp = DataPoint.objects.filter(value="apple") + self.assertEqual(list(resp), [self.d0]) + + def test_update_multiple_objects(self): + """ + We can update multiple objects at once. 
+ """ + resp = DataPoint.objects.filter(value="banana").update( + value="pineapple") + self.assertEqual(resp, 2) + self.assertEqual(DataPoint.objects.get(name="d2").value, 'pineapple') + + def test_update_fk(self): + """ + Foreign key fields can also be updated, although you can only update + the object referred to, not anything inside the related object. + """ + resp = RelatedPoint.objects.filter(name="r1").update(data=self.d0) + self.assertEqual(resp, 1) + resp = RelatedPoint.objects.filter(data__name="d0") + self.assertEqual(list(resp), [self.r1]) + + def test_update_multiple_fields(self): + """ + Multiple fields can be updated at once + """ + resp = DataPoint.objects.filter(value="apple").update( + value="fruit", another_value="peach") + self.assertEqual(resp, 1) + d = DataPoint.objects.get(name="d0") + self.assertEqual(d.value, 'fruit') + self.assertEqual(d.another_value, 'peach') + + def test_update_all(self): + """ + In the rare case you want to update every instance of a model, update() + is also a manager method. + """ + self.assertEqual(DataPoint.objects.update(value='thing'), 3) + resp = DataPoint.objects.values('value').distinct() + self.assertEqual(list(resp), [{'value': 'thing'}]) + + def test_update_slice_fail(self): + """ + We do not support update on already sliced query sets. + """ + method = DataPoint.objects.all()[:2].update + with self.assertRaises(AssertionError): + method(another_value='another thing') + + def test_update_respects_to_field(self): + """ + Update of an FK field which specifies a to_field works. 
+ """ + a_foo = Foo.objects.create(target='aaa') + b_foo = Foo.objects.create(target='bbb') + bar = Bar.objects.create(foo=a_foo) + self.assertEqual(bar.foo_id, a_foo.target) + bar_qs = Bar.objects.filter(pk=bar.pk) + self.assertEqual(bar_qs[0].foo_id, a_foo.target) + bar_qs.update(foo=b_foo) + self.assertEqual(bar_qs[0].foo_id, b_foo.target) + + def test_update_annotated_queryset(self): + """ + Update of a queryset that's been annotated. + """ + # Trivial annotated update + qs = DataPoint.objects.annotate(alias=F('value')) + self.assertEqual(qs.update(another_value='foo'), 3) + # Update where annotation is used for filtering + qs = DataPoint.objects.annotate(alias=F('value')).filter(alias='apple') + self.assertEqual(qs.update(another_value='foo'), 1) + # Update where annotation is used in update parameters + qs = DataPoint.objects.annotate(alias=F('value')) + self.assertEqual(qs.update(another_value=F('alias')), 3) + # Update where aggregation annotation is used in update parameters + qs = DataPoint.objects.annotate(max=Max('value')) + with self.assertRaisesMessage(FieldError, 'Aggregate functions are not allowed in this query'): + qs.update(another_value=F('max')) + + def test_update_annotated_multi_table_queryset(self): + """ + Update of a queryset that's been annotated and involves multiple tables. 
+ """ + # TODO: fix + return + # Trivial annotated update + qs = DataPoint.objects.annotate(related_count=Count('relatedpoint')) + self.assertEqual(qs.update(value='Foo'), 3) + # Update where annotation is used for filtering + qs = DataPoint.objects.annotate(related_count=Count('relatedpoint')) + self.assertEqual(qs.filter(related_count=1).update(value='Foo'), 1) + # Update where annotation is used in update parameters + # #26539 - This isn't forbidden but also doesn't generate proper SQL + # qs = RelatedPoint.objects.annotate(data_name=F('data__name')) + # updated = qs.update(name=F('data_name')) + # self.assertEqual(updated, 1) + # Update where aggregation annotation is used in update parameters + qs = RelatedPoint.objects.annotate(max=Max('data__value')) + with self.assertRaisesMessage(FieldError, 'Aggregate functions are not allowed in this query'): + qs.update(name=F('max')) diff --git a/tests/update_only_fields/__init__.py b/tests/update_only_fields/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/update_only_fields/models.py b/tests/update_only_fields/models.py new file mode 100644 index 00000000..a3be5088 --- /dev/null +++ b/tests/update_only_fields/models.py @@ -0,0 +1,42 @@ + +from django.db import models +from django.utils.encoding import python_2_unicode_compatible + +GENDER_CHOICES = ( + ('M', 'Male'), + ('F', 'Female'), +) + + +class Account(models.Model): + num = models.IntegerField() + + +@python_2_unicode_compatible +class Person(models.Model): + name = models.CharField(max_length=20) + gender = models.CharField(max_length=1, choices=GENDER_CHOICES) + pid = models.IntegerField(null=True, default=None) + + def __str__(self): + return self.name + + +class Employee(Person): + employee_num = models.IntegerField(default=0) + profile = models.ForeignKey('Profile', models.SET_NULL, related_name='profiles', null=True) + accounts = models.ManyToManyField('Account', related_name='employees', blank=True) + + 
+@python_2_unicode_compatible +class Profile(models.Model): + name = models.CharField(max_length=200) + salary = models.FloatField(default=1000.0) + + def __str__(self): + return self.name + + +class ProxyEmployee(Employee): + class Meta: + proxy = True diff --git a/tests/update_only_fields/tests.py b/tests/update_only_fields/tests.py new file mode 100644 index 00000000..7627bcd3 --- /dev/null +++ b/tests/update_only_fields/tests.py @@ -0,0 +1,254 @@ +from __future__ import unicode_literals + +from django.db.models.signals import post_save, pre_save +from django.test import TestCase + +from .models import Account, Employee, Person, Profile, ProxyEmployee + + +class UpdateOnlyFieldsTests(TestCase): + def test_update_fields_basic(self): + s = Person.objects.create(name='Sara', gender='F') + self.assertEqual(s.gender, 'F') + + s.gender = 'M' + s.name = 'Ian' + s.save(update_fields=['name']) + + s = Person.objects.get(pk=s.pk) + self.assertEqual(s.gender, 'F') + self.assertEqual(s.name, 'Ian') + + def test_update_fields_deferred(self): + s = Person.objects.create(name='Sara', gender='F', pid=22) + self.assertEqual(s.gender, 'F') + + s1 = Person.objects.defer("gender", "pid").get(pk=s.pk) + s1.name = "Emily" + s1.gender = "M" + + with self.assertNumQueries(1): + s1.save() + + s2 = Person.objects.get(pk=s1.pk) + self.assertEqual(s2.name, "Emily") + self.assertEqual(s2.gender, "M") + + def test_update_fields_only_1(self): + s = Person.objects.create(name='Sara', gender='F') + self.assertEqual(s.gender, 'F') + + s1 = Person.objects.only('name').get(pk=s.pk) + s1.name = "Emily" + s1.gender = "M" + + with self.assertNumQueries(1): + s1.save() + + s2 = Person.objects.get(pk=s1.pk) + self.assertEqual(s2.name, "Emily") + self.assertEqual(s2.gender, "M") + + def test_update_fields_only_2(self): + s = Person.objects.create(name='Sara', gender='F', pid=22) + self.assertEqual(s.gender, 'F') + + s1 = Person.objects.only('name').get(pk=s.pk) + s1.name = "Emily" + s1.gender = "M" + + 
with self.assertNumQueries(2): + s1.save(update_fields=['pid']) + + s2 = Person.objects.get(pk=s1.pk) + self.assertEqual(s2.name, "Sara") + self.assertEqual(s2.gender, "F") + + def test_update_fields_only_repeated(self): + s = Person.objects.create(name='Sara', gender='F') + self.assertEqual(s.gender, 'F') + + s1 = Person.objects.only('name').get(pk=s.pk) + s1.gender = 'M' + with self.assertNumQueries(1): + s1.save() + # save() should not fetch deferred fields + s1 = Person.objects.only('name').get(pk=s.pk) + with self.assertNumQueries(1): + s1.save() + + def test_update_fields_inheritance_defer(self): + profile_boss = Profile.objects.create(name='Boss', salary=3000) + e1 = Employee.objects.create(name='Sara', gender='F', employee_num=1, profile=profile_boss) + e1 = Employee.objects.only('name').get(pk=e1.pk) + e1.name = 'Linda' + with self.assertNumQueries(1): + e1.save() + self.assertEqual(Employee.objects.get(pk=e1.pk).name, 'Linda') + + def test_update_fields_fk_defer(self): + profile_boss = Profile.objects.create(name='Boss', salary=3000) + profile_receptionist = Profile.objects.create(name='Receptionist', salary=1000) + e1 = Employee.objects.create(name='Sara', gender='F', employee_num=1, profile=profile_boss) + e1 = Employee.objects.only('profile').get(pk=e1.pk) + e1.profile = profile_receptionist + with self.assertNumQueries(1): + e1.save() + self.assertEqual(Employee.objects.get(pk=e1.pk).profile, profile_receptionist) + e1.profile_id = profile_boss.pk + with self.assertNumQueries(1): + e1.save() + self.assertEqual(Employee.objects.get(pk=e1.pk).profile, profile_boss) + + def test_select_related_only_interaction(self): + profile_boss = Profile.objects.create(name='Boss', salary=3000) + e1 = Employee.objects.create(name='Sara', gender='F', employee_num=1, profile=profile_boss) + e1 = Employee.objects.only('profile__salary').select_related('profile').get(pk=e1.pk) + profile_boss.name = 'Clerk' + profile_boss.salary = 1000 + profile_boss.save() + # The loaded 
salary of 3000 gets saved, the name of 'Clerk' isn't + # overwritten. + with self.assertNumQueries(1): + e1.profile.save() + reloaded_profile = Profile.objects.get(pk=profile_boss.pk) + self.assertEqual(reloaded_profile.name, profile_boss.name) + self.assertEqual(reloaded_profile.salary, 3000) + + def test_update_fields_m2m(self): + profile_boss = Profile.objects.create(name='Boss', salary=3000) + e1 = Employee.objects.create(name='Sara', gender='F', employee_num=1, profile=profile_boss) + a1 = Account.objects.create(num=1) + a2 = Account.objects.create(num=2) + e1.accounts.set([a1, a2]) + + with self.assertRaises(ValueError): + e1.save(update_fields=['accounts']) + + def test_update_fields_inheritance(self): + profile_boss = Profile.objects.create(name='Boss', salary=3000) + profile_receptionist = Profile.objects.create(name='Receptionist', salary=1000) + e1 = Employee.objects.create(name='Sara', gender='F', employee_num=1, profile=profile_boss) + + e1.name = 'Ian' + e1.gender = 'M' + e1.save(update_fields=['name']) + + e2 = Employee.objects.get(pk=e1.pk) + self.assertEqual(e2.name, 'Ian') + self.assertEqual(e2.gender, 'F') + self.assertEqual(e2.profile, profile_boss) + + e2.profile = profile_receptionist + e2.name = 'Sara' + e2.save(update_fields=['profile']) + + e3 = Employee.objects.get(pk=e1.pk) + self.assertEqual(e3.name, 'Ian') + self.assertEqual(e3.profile, profile_receptionist) + + with self.assertNumQueries(1): + e3.profile = profile_boss + e3.save(update_fields=['profile_id']) + + e4 = Employee.objects.get(pk=e3.pk) + self.assertEqual(e4.profile, profile_boss) + self.assertEqual(e4.profile_id, profile_boss.pk) + + def test_update_fields_inheritance_with_proxy_model(self): + profile_boss = Profile.objects.create(name='Boss', salary=3000) + profile_receptionist = Profile.objects.create(name='Receptionist', salary=1000) + e1 = ProxyEmployee.objects.create(name='Sara', gender='F', employee_num=1, profile=profile_boss) + + e1.name = 'Ian' + e1.gender = 'M' + 
e1.save(update_fields=['name']) + + e2 = ProxyEmployee.objects.get(pk=e1.pk) + self.assertEqual(e2.name, 'Ian') + self.assertEqual(e2.gender, 'F') + self.assertEqual(e2.profile, profile_boss) + + e2.profile = profile_receptionist + e2.name = 'Sara' + e2.save(update_fields=['profile']) + + e3 = ProxyEmployee.objects.get(pk=e1.pk) + self.assertEqual(e3.name, 'Ian') + self.assertEqual(e3.profile, profile_receptionist) + + def test_update_fields_signals(self): + p = Person.objects.create(name='Sara', gender='F') + pre_save_data = [] + + def pre_save_receiver(**kwargs): + pre_save_data.append(kwargs['update_fields']) + pre_save.connect(pre_save_receiver) + post_save_data = [] + + def post_save_receiver(**kwargs): + post_save_data.append(kwargs['update_fields']) + post_save.connect(post_save_receiver) + p.save(update_fields=['name']) + self.assertEqual(len(pre_save_data), 1) + self.assertEqual(len(pre_save_data[0]), 1) + self.assertIn('name', pre_save_data[0]) + self.assertEqual(len(post_save_data), 1) + self.assertEqual(len(post_save_data[0]), 1) + self.assertIn('name', post_save_data[0]) + + pre_save.disconnect(pre_save_receiver) + post_save.disconnect(post_save_receiver) + + def test_update_fields_incorrect_params(self): + s = Person.objects.create(name='Sara', gender='F') + + with self.assertRaises(ValueError): + s.save(update_fields=['first_name']) + + with self.assertRaises(ValueError): + s.save(update_fields="name") + + def test_empty_update_fields(self): + s = Person.objects.create(name='Sara', gender='F') + pre_save_data = [] + + def pre_save_receiver(**kwargs): + pre_save_data.append(kwargs['update_fields']) + pre_save.connect(pre_save_receiver) + post_save_data = [] + + def post_save_receiver(**kwargs): + post_save_data.append(kwargs['update_fields']) + post_save.connect(post_save_receiver) + # Save is skipped. + with self.assertNumQueries(0): + s.save(update_fields=[]) + # Signals were skipped, too... 
+ self.assertEqual(len(pre_save_data), 0) + self.assertEqual(len(post_save_data), 0) + + pre_save.disconnect(pre_save_receiver) + post_save.disconnect(post_save_receiver) + + def test_num_queries_inheritance(self): + s = Employee.objects.create(name='Sara', gender='F') + s.employee_num = 1 + s.name = 'Emily' + with self.assertNumQueries(1): + s.save(update_fields=['employee_num']) + s = Employee.objects.get(pk=s.pk) + self.assertEqual(s.employee_num, 1) + self.assertEqual(s.name, 'Sara') + s.employee_num = 2 + s.name = 'Emily' + with self.assertNumQueries(1): + s.save(update_fields=['name']) + s = Employee.objects.get(pk=s.pk) + self.assertEqual(s.name, 'Emily') + self.assertEqual(s.employee_num, 1) + # A little sanity check that we actually did updates... + self.assertEqual(Employee.objects.count(), 1) + self.assertEqual(Person.objects.count(), 1) + with self.assertNumQueries(2): + s.save(update_fields=['name', 'employee_num']) diff --git a/tests/urls.py b/tests/urls.py new file mode 100644 index 00000000..7d3a3a79 --- /dev/null +++ b/tests/urls.py @@ -0,0 +1,7 @@ +"""This URLconf exists because Django expects ROOT_URLCONF to exist. URLs +should be added within the test folders, and use TestCase.urls to set them. +This helps the tests remain isolated. +""" + + +urlpatterns = []