diff --git a/tests/raw_query/models.py b/tests/raw_query/models.py index c2fd72047b..ba3dd56519 100644 --- a/tests/raw_query/models.py +++ b/tests/raw_query/models.py @@ -1,4 +1,5 @@ from django_mongodb.fields import ObjectIdAutoField +from django_mongodb.managers import MongoManager from django.db import models @@ -8,6 +9,8 @@ class Author(models.Model): last_name = models.CharField(max_length=255) dob = models.DateField() + objects = MongoManager() + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Protect against annotations being passed to __init__ -- @@ -25,25 +28,35 @@ class Book(models.Model): paperback = models.BooleanField(default=False) opening_line = models.TextField() + objects = MongoManager() + class BookFkAsPk(models.Model): book = models.ForeignKey( Book, models.CASCADE, primary_key=True, db_column="not_the_default" ) + objects = MongoManager() + class Coffee(models.Model): brand = models.CharField(max_length=255, db_column="name") price = models.DecimalField(max_digits=10, decimal_places=2, default=0) + objects = MongoManager() + class MixedCaseIDColumn(models.Model): id = ObjectIdAutoField(primary_key=True, db_column="MiXeD_CaSe_Id") + objects = MongoManager() + class Reviewer(models.Model): reviewed = models.ManyToManyField(Book) + objects = MongoManager() + class FriendlyAuthor(Author): pass diff --git a/tests/raw_query/tests.py b/tests/raw_query/tests.py index 1dcc7ce740..f515704284 100644 --- a/tests/raw_query/tests.py +++ b/tests/raw_query/tests.py @@ -1,9 +1,13 @@ +""" +These tests are adapted from Django's tests/raw_query/tests.py. +""" + from datetime import date -from decimal import Decimal + +from django_mongodb.queryset import RawQuerySet from django.core.exceptions import FieldDoesNotExist -from django.db.models.query import RawQuerySet -from django.test import TestCase, skipUnlessDBFeature +from django.test import TestCase from .models import ( Author, @@ -16,7 +20,7 @@ ) -class RawQueryTests(TestCase): +class RawAggregateTests(TestCase): @classmethod def setUpTestData(cls): cls.a1 = Author.objects.create( @@ -74,22 +78,18 @@ def assertSuccessfulRawQuery( query, expected_results, expected_annotations=(), - params=[], - translations=None, ): """ - Execute the passed query against the passed model and check the output + Execute the passed query against the passed model and check the output. """ - results = list( - model.objects.raw(query, params=params, translations=translations) - ) - self.assertProcessed(model, results, expected_results, expected_annotations) - self.assertAnnotations(results, expected_annotations) + results = list(model.objects.raw_aggregate(query)) + expected_results = list(expected_results) + with self.assertNumQueries(0): + self.assertProcessed(model, results, expected_results, expected_annotations) + self.assertAnnotations(results, expected_annotations) def assertProcessed(self, model, results, orig, expected_annotations=()): - """ - Compare the results of a raw query against expected results - """ + """Compare the results of a raw query against expected results.""" self.assertEqual(len(results), len(orig)) for index, item in enumerate(results): orig_item = orig[index] @@ -97,26 +97,22 @@ def assertProcessed(self, model, results, orig, expected_annotations=()): setattr(orig_item, *annotation) for field in model._meta.fields: - # All values on the model are equal + # All values on the model are equal. self.assertEqual( getattr(item, field.attname), getattr(orig_item, field.attname) ) - # This includes checking that they are the same type + # This includes checking that they are the same type. self.assertEqual( type(getattr(item, field.attname)), type(getattr(orig_item, field.attname)), ) def assertNoAnnotations(self, results): - """ - The results of a raw query contain no annotations - """ + """The results of a raw query contain no annotations.""" self.assertAnnotations(results, ()) def assertAnnotations(self, results, expected_annotations): - """ - The passed raw query results contain the expected annotations - """ + """The passed raw query results contain the expected annotations.""" if expected_annotations: for index, result in enumerate(results): annotation, value = expected_annotations[index] @@ -124,19 +120,13 @@ def assertAnnotations(self, results, expected_annotations): self.assertEqual(getattr(result, annotation), value) def test_rawqueryset_repr(self): - queryset = RawQuerySet(raw_query="SELECT * FROM raw_query_author") - self.assertEqual( - repr(queryset), "" - ) - self.assertEqual( - repr(queryset.query), "" - ) + queryset = RawQuerySet(pipeline=[]) + self.assertEqual(repr(queryset), "") + self.assertEqual(repr(queryset.query), "") def test_simple_raw_query(self): - """ - Basic test of raw query with a simple database query - """ - query = "SELECT * FROM raw_query_author" + """Basic test of raw query with a simple database query.""" + query = [] authors = Author.objects.all() self.assertSuccessfulRawQuery(Author, query, authors) @@ -145,16 +135,16 @@ def test_raw_query_lazy(self): Raw queries are lazy: they aren't actually executed until they're iterated over. """ - q = Author.objects.raw("SELECT * FROM raw_query_author") + q = Author.objects.raw_aggregate([]) self.assertIsNone(q.query.cursor) list(q) self.assertIsNotNone(q.query.cursor) def test_FK_raw_query(self): """ - Test of a simple raw query against a model containing a foreign key + Test of a simple raw query against a model containing a foreign key. """ - query = "SELECT * FROM raw_query_book" + query = [] books = Book.objects.all() self.assertSuccessfulRawQuery(Book, query, books) @@ -163,7 +153,7 @@ def test_db_column_handler(self): Test of a simple raw query against a model containing a field with db_column defined. """ - query = "SELECT * FROM raw_query_coffee" + query = [] coffees = Coffee.objects.all() self.assertSuccessfulRawQuery(Coffee, query, coffees) @@ -171,157 +161,85 @@ def test_pk_with_mixed_case_db_column(self): """ A raw query with a model that has a pk db_column with mixed case. """ - query = "SELECT * FROM raw_query_mixedcaseidcolumn" + query = [] queryset = MixedCaseIDColumn.objects.all() self.assertSuccessfulRawQuery(MixedCaseIDColumn, query, queryset) def test_order_handler(self): """ Test of raw raw query's tolerance for columns being returned in any - order + order. """ selects = ( ("dob, last_name, first_name, id"), ("last_name, dob, first_name, id"), ("first_name, last_name, dob, id"), ) - for select in selects: - query = "SELECT %s FROM raw_query_author" % select + select = {col: 1 for col in select.split(", ")} + query = [{"$project": select}] authors = Author.objects.all() self.assertSuccessfulRawQuery(Author, query, authors) - def test_translations(self): - """ - Test of raw query's optional ability to translate unexpected result - column names to specific model fields - """ - query = ( - "SELECT first_name AS first, last_name AS last, dob, id " - "FROM raw_query_author" - ) - translations = {"first": "first_name", "last": "last_name"} - authors = Author.objects.all() - self.assertSuccessfulRawQuery(Author, query, authors, translations=translations) - - def test_params(self): - """ - Test passing optional query parameters - """ - query = "SELECT * FROM raw_query_author WHERE first_name = %s" - author = Author.objects.all()[2] - params = [author.first_name] - qset = Author.objects.raw(query, params=params) - results = list(qset) - self.assertProcessed(Author, results, [author]) - self.assertNoAnnotations(results) - self.assertEqual(len(results), 1) - self.assertIsInstance(repr(qset), str) - - def test_params_none(self): - query = "SELECT * FROM raw_query_author WHERE first_name like 'J%'" - qset = Author.objects.raw(query, params=None) - self.assertEqual(len(qset), 2) - - def test_escaped_percent(self): - query = "SELECT * FROM raw_query_author WHERE first_name like 'J%%'" - qset = Author.objects.raw(query) - self.assertEqual(len(qset), 2) - - @skipUnlessDBFeature("supports_paramstyle_pyformat") - def test_pyformat_params(self): - """ - Test passing optional query parameters - """ - query = "SELECT * FROM raw_query_author WHERE first_name = %(first)s" - author = Author.objects.all()[2] - params = {"first": author.first_name} - qset = Author.objects.raw(query, params=params) - results = list(qset) - self.assertProcessed(Author, results, [author]) - self.assertNoAnnotations(results) - self.assertEqual(len(results), 1) - self.assertIsInstance(repr(qset), str) - def test_query_representation(self): - """ - Test representation of raw query with parameters - """ - query = "SELECT * FROM raw_query_author WHERE last_name = %(last)s" - qset = Author.objects.raw(query, {"last": "foo"}) - self.assertEqual( - repr(qset), - "", - ) - self.assertEqual( - repr(qset.query), - "", - ) - - query = "SELECT * FROM raw_query_author WHERE last_name = %s" - qset = Author.objects.raw(query, {"foo"}) + """Test representation of raw query.""" + query = [{"$match": {"last_name": "foo"}}] + qset = Author.objects.raw_aggregate(query) self.assertEqual( repr(qset), - "", + "", ) self.assertEqual( repr(qset.query), - "", + "", ) def test_many_to_many(self): """ - Test of a simple raw query against a model containing a m2m field + Test of a simple raw query against a model containing a m2m field. """ - query = "SELECT * FROM raw_query_reviewer" + query = [] reviewers = Reviewer.objects.all() self.assertSuccessfulRawQuery(Reviewer, query, reviewers) - def test_extra_conversions(self): - """Extra translations are ignored.""" - query = "SELECT * FROM raw_query_author" - translations = {"something": "else"} - authors = Author.objects.all() - self.assertSuccessfulRawQuery(Author, query, authors, translations=translations) - def test_missing_fields(self): - query = "SELECT id, first_name, dob FROM raw_query_author" - for author in Author.objects.raw(query): + query = [{"$project": {"id": 1, "first_name": 1, "dob": 1}}] + for author in Author.objects.raw_aggregate(query): self.assertIsNotNone(author.first_name) - # last_name isn't given, but it will be retrieved on demand + # last_name isn't given, but it will be retrieved on demand. self.assertIsNotNone(author.last_name) def test_missing_fields_without_PK(self): - query = "SELECT first_name, dob FROM raw_query_author" + query = [{"$project": {"first_name": 1, "dob": 1, "_id": 0}}] msg = "Raw query must include the primary key" with self.assertRaisesMessage(FieldDoesNotExist, msg): - list(Author.objects.raw(query)) + list(Author.objects.raw_aggregate(query)) def test_annotations(self): - query = ( - "SELECT a.*, count(b.id) as book_count " - "FROM raw_query_author a " - "LEFT JOIN raw_query_book b ON a.id = b.author_id " - "GROUP BY a.id, a.first_name, a.last_name, a.dob ORDER BY a.id" - ) + query = [ + { + "$project": { + "first_name": 1, + "last_name": 1, + "dob": 1, + "birth_year": {"$year": "$dob"}, + }, + }, + {"$sort": {"_id": 1}}, + ] expected_annotations = ( - ("book_count", 3), - ("book_count", 0), - ("book_count", 1), - ("book_count", 0), + ("birth_year", 1950), + ("birth_year", 1920), + ("birth_year", 1986), + ("birth_year", 1932), ) authors = Author.objects.order_by("pk") self.assertSuccessfulRawQuery(Author, query, authors, expected_annotations) - def test_white_space_query(self): - query = " SELECT * FROM raw_query_author" - authors = Author.objects.all() - self.assertSuccessfulRawQuery(Author, query, authors) - def test_multiple_iterations(self): - query = "SELECT * FROM raw_query_author" + query = [] normal_authors = Author.objects.all() - raw_authors = Author.objects.raw(query) + raw_authors = Author.objects.raw_aggregate(query) # First Iteration first_iterations = 0 @@ -339,33 +257,32 @@ def test_multiple_iterations(self): def test_get_item(self): # Indexing on RawQuerySets - query = "SELECT * FROM raw_query_author ORDER BY id ASC" - third_author = Author.objects.raw(query)[2] + query = [{"$sort": {"id": 1}}] + third_author = Author.objects.raw_aggregate(query)[2] self.assertEqual(third_author.first_name, "Bob") - first_two = Author.objects.raw(query)[0:2] + first_two = Author.objects.raw_aggregate(query)[0:2] self.assertEqual(len(first_two), 2) with self.assertRaises(TypeError): - Author.objects.raw(query)["test"] + Author.objects.raw_aggregate(query)["test"] def test_inheritance(self): f = FriendlyAuthor.objects.create( first_name="Wesley", last_name="Chun", dob=date(1962, 10, 28) ) - query = "SELECT * FROM raw_query_friendlyauthor" - self.assertEqual([o.pk for o in FriendlyAuthor.objects.raw(query)], [f.pk]) + query = [] + self.assertEqual( + [o.pk for o in FriendlyAuthor.objects.raw_aggregate(query)], [f.pk] + ) def test_query_count(self): - self.assertNumQueries( - 1, list, Author.objects.raw("SELECT * FROM raw_query_author") - ) + self.assertNumQueries(1, list, Author.objects.raw_aggregate([])) def test_subquery_in_raw_sql(self): list( - Book.objects.raw( - "SELECT id FROM " - "(SELECT * FROM raw_query_book WHERE paperback IS NOT NULL) sq" + Book.objects.raw_aggregate( + [{"$match": {"paperback": {"$ne": None}}}, {"$project": {"id": 1}}] ) ) @@ -380,40 +297,29 @@ def test_db_column_name_is_used_in_raw_query(self): b = BookFkAsPk.objects.create(book=self.b1) self.assertEqual( list( - BookFkAsPk.objects.raw( - "SELECT not_the_default FROM raw_query_bookfkaspk" + BookFkAsPk.objects.raw_aggregate( + [{"$project": {"not_the_default": 1, "_id": 0}}] ) ), [b], ) - def test_decimal_parameter(self): - c = Coffee.objects.create(brand="starbucks", price=20.5) - qs = Coffee.objects.raw( - "SELECT * FROM raw_query_coffee WHERE price >= %s", params=[Decimal(20)] - ) - self.assertEqual(list(qs), [c]) - def test_result_caching(self): with self.assertNumQueries(1): - books = Book.objects.raw("SELECT * FROM raw_query_book") + books = Book.objects.raw_aggregate([]) list(books) list(books) def test_iterator(self): with self.assertNumQueries(2): - books = Book.objects.raw("SELECT * FROM raw_query_book") + books = Book.objects.raw_aggregate([]) list(books.iterator()) list(books.iterator()) def test_bool(self): - self.assertIs(bool(Book.objects.raw("SELECT * FROM raw_query_book")), True) - self.assertIs( - bool(Book.objects.raw("SELECT * FROM raw_query_book WHERE id = 0")), False - ) + self.assertIs(bool(Book.objects.raw_aggregate([])), True) + self.assertIs(bool(Book.objects.raw_aggregate([{"$match": {"id": 0}}])), False) def test_len(self): - self.assertEqual(len(Book.objects.raw("SELECT * FROM raw_query_book")), 4) - self.assertEqual( - len(Book.objects.raw("SELECT * FROM raw_query_book WHERE id = 0")), 0 - ) + self.assertEqual(len(Book.objects.raw_aggregate([])), 4) + self.assertEqual(len(Book.objects.raw_aggregate([{"$match": {"id": 0}}])), 0)