diff --git a/.github/workflows/runtests.py b/.github/workflows/runtests.py index 3d01657d7..62610b28c 100755 --- a/.github/workflows/runtests.py +++ b/.github/workflows/runtests.py @@ -120,6 +120,7 @@ "queries", "queries_", "queryset_pickle", + "raw_query_", "redirects_tests", "reserved_names", "reverse_lookup", diff --git a/django_mongodb/managers.py b/django_mongodb/managers.py new file mode 100644 index 000000000..055c84402 --- /dev/null +++ b/django_mongodb/managers.py @@ -0,0 +1,7 @@ +from django.db.models.manager import BaseManager + +from .queryset import MongoQuerySet + + +class MongoManager(BaseManager.from_queryset(MongoQuerySet)): + pass diff --git a/django_mongodb/queryset.py b/django_mongodb/queryset.py new file mode 100644 index 000000000..c8b351e31 --- /dev/null +++ b/django_mongodb/queryset.py @@ -0,0 +1,96 @@ +from itertools import chain + +from django.core.exceptions import FieldDoesNotExist +from django.db import connections +from django.db.models import QuerySet +from django.db.models.query import RawModelIterable as BaseRawModelIterable +from django.db.models.query import RawQuerySet as BaseRawQuerySet +from django.db.models.sql.query import RawQuery as BaseRawQuery + + +class MongoQuerySet(QuerySet): + def raw_aggregate(self, pipeline, using=None): + return RawQuerySet(pipeline, model=self.model, using=using) + + +class RawQuerySet(BaseRawQuerySet): + def __init__(self, pipeline, model=None, using=None): + super().__init__(pipeline, model=model, using=using) + self.query = RawQuery(pipeline, using=self.db, model=self.model) + # Override the superclass's columns property which relies on PEP 249's + # cursor.description. Instead, RawModelIterable will set the columns + # based on the keys in the first result. 
+ self.columns = None + + def iterator(self): + yield from RawModelIterable(self) + + +class RawQuery(BaseRawQuery): + def __init__(self, pipeline, using, model): + self.pipeline = pipeline + super().__init__(sql=None, using=using) + self.model = model + + def _execute_query(self): + connection = connections[self.using] + collection = connection.get_collection(self.model._meta.db_table) + self.cursor = collection.aggregate(self.pipeline) + + def __str__(self): + return str(self.pipeline) + + +class RawModelIterable(BaseRawModelIterable): + def __iter__(self): + """ + This is copied from the superclass except for the part that sets + self.queryset.columns from the first result. + """ + db = self.queryset.db + query = self.queryset.query + connection = connections[db] + compiler = connection.ops.compiler("SQLCompiler")(query, connection, db) + query_iterator = iter(query) + try: + # Get the columns from the first result. + try: + first_result = next(query_iterator) + except StopIteration: + # No results. + return + self.queryset.columns = list(first_result.keys()) + # Reset the iterator to include the first item. 
+ query_iterator = self._make_result(chain([first_result], query_iterator)) + ( + model_init_names, + model_init_pos, + annotation_fields, + ) = self.queryset.resolve_model_init_order() + model_cls = self.queryset.model + if model_cls._meta.pk.attname not in model_init_names: + raise FieldDoesNotExist("Raw query must include the primary key") + fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns] + converters = compiler.get_converters( + [f.get_col(f.model._meta.db_table) if f else None for f in fields] + ) + if converters: + query_iterator = compiler.apply_converters(query_iterator, converters) + for values in query_iterator: + # Associate fields to values + model_init_values = [values[pos] for pos in model_init_pos] + instance = model_cls.from_db(db, model_init_names, model_init_values) + if annotation_fields: + for column, pos in annotation_fields: + setattr(instance, column, values[pos]) + yield instance + finally: + query.cursor.close() + + def _make_result(self, query): + """ + Convert documents (dictionaries) to tuples as expected by the rest + of __iter__(). + """ + for result in query: + yield tuple(result.values()) diff --git a/docs/source/conf.py b/docs/source/conf.py index b8c7862d8..564f763a6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -17,11 +17,25 @@ # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -extensions = [] +# If true, the current module name will be prepended to all description +# unit titles (such as .. function::). 
+add_module_names = False + +extensions = [ + "sphinx.ext.intersphinx", +] # templates_path = ["_templates"] exclude_patterns = [] +intersphinx_mapping = { + "django": ( + "https://docs.djangoproject.com/en/5.0/", + "http://docs.djangoproject.com/en/5.0/_objects/", + ), + "pymongo": ("https://pymongo.readthedocs.io/en/stable/", None), + "python": ("https://docs.python.org/3/", None), +} # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output diff --git a/docs/source/index.rst b/docs/source/index.rst index dbbcd0b15..e9a2ace6d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,16 +1,11 @@ -.. django_mongodb documentation master file, created by - sphinx-quickstart on Mon Apr 15 12:38:26 2024. - You can adapt this file completely to your liking, but it should at least - contain the root ``toctree`` directive. - -Welcome to django_mongodb's documentation! -========================================== +django-mongodb 5.0.x documentation +================================== .. toctree:: - :maxdepth: 2 + :maxdepth: 1 :caption: Contents: - + querysets Indices and tables ================== diff --git a/docs/source/querysets.rst b/docs/source/querysets.rst new file mode 100644 index 000000000..fbdde9338 --- /dev/null +++ b/docs/source/querysets.rst @@ -0,0 +1,69 @@ +``QuerySet`` API reference +========================== + +Some MongoDB-specific ``QuerySet`` methods are available by adding a custom +:class:`~django.db.models.Manager`, ``MongoManager``, to your model:: + + from django.db import models + + from django_mongodb.managers import MongoManager + + + class MyModel(models.Model): + ... + + objects = MongoManager() + + +.. currentmodule:: django_mongodb.queryset.MongoQuerySet + +``raw_aggregate()`` +------------------- + +.. 
method:: raw_aggregate(pipeline, using=None) + +Similar to :meth:`QuerySet.raw() <django.db.models.query.QuerySet.raw>`, but +instead of a raw SQL query, this method accepts a pipeline that will be passed +to :meth:`pymongo.collection.Collection.aggregate`. + +For example, you could write custom match criteria:: + + Question.objects.raw_aggregate([{"$match": {"question_text": "What's up"}}]) + +The pipeline may also return additional fields that will be added as +annotations on the models:: + + >>> questions = Question.objects.raw_aggregate([{ + ... "$project": { + ... "question_text": 1, + ... "pub_date": 1, + ... "year_published": {"$year": "$pub_date"} + ... } + ... }]) + >>> for q in questions: + ... print(f"{q.question_text} was published in {q.year_published}.") + ... + What's up? was published in 2024. + +Fields may also be left out: + + >>> Question.objects.raw_aggregate([{"$project": {"question_text": 1}}]) + +The ``Question`` objects returned by this query will be deferred model instances +(see :meth:`~django.db.models.query.QuerySet.defer()`). This means that the +fields that are omitted from the query will be loaded on demand. For example:: + + >>> for q in Question.objects.raw_aggregate([{"$project": {"question_text": 1}}]): + ... print( + ... q.question_text, # This will be retrieved by the original query. + ... q.pub_date, # This will be retrieved on demand. + ... ) + ... + What's new 2023-09-03 12:00:00+00:00 + What's up 2024-08-23 20:57:30+00:00 + +From outward appearances, this looks like the query has retrieved both the +question text and published date. However, this example actually issued three +queries. Only the question texts were retrieved by the ``raw_aggregate()`` +query -- the published dates were both retrieved on demand when they were +printed. 
diff --git a/tests/raw_query_/__init__.py b/tests/raw_query_/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/raw_query_/models.py b/tests/raw_query_/models.py new file mode 100644 index 000000000..aab01de6c --- /dev/null +++ b/tests/raw_query_/models.py @@ -0,0 +1,60 @@ +from django.db import models + +from django_mongodb.fields import ObjectIdAutoField +from django_mongodb.managers import MongoManager + + +class Author(models.Model): + first_name = models.CharField(max_length=255) + last_name = models.CharField(max_length=255) + dob = models.DateField() + + objects = MongoManager() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Protect against annotations being passed to __init__ -- + # this'll make the test suite get angry if annotations aren't + # treated differently than fields. + for k in kwargs: + assert k in [f.attname for f in self._meta.fields], ( + "Author.__init__ got an unexpected parameter: %s" % k + ) + + +class Book(models.Model): + title = models.CharField(max_length=255) + author = models.ForeignKey(Author, models.CASCADE) + paperback = models.BooleanField(default=False) + opening_line = models.TextField() + + objects = MongoManager() + + +class BookFkAsPk(models.Model): + book = models.ForeignKey(Book, models.CASCADE, primary_key=True, db_column="not_the_default") + + objects = MongoManager() + + +class Coffee(models.Model): + brand = models.CharField(max_length=255, db_column="name") + price = models.DecimalField(max_digits=10, decimal_places=2, default=0) + + objects = MongoManager() + + +class MixedCaseIDColumn(models.Model): + id = ObjectIdAutoField(primary_key=True, db_column="MiXeD_CaSe_Id") + + objects = MongoManager() + + +class Reviewer(models.Model): + reviewed = models.ManyToManyField(Book) + + objects = MongoManager() + + +class FriendlyAuthor(Author): + pass diff --git a/tests/raw_query_/test_raw_aggregate.py b/tests/raw_query_/test_raw_aggregate.py new file mode 100644 
index 000000000..dd2fd48b4 --- /dev/null +++ b/tests/raw_query_/test_raw_aggregate.py @@ -0,0 +1,310 @@ +""" +These tests are adapted from Django's tests/raw_query/tests.py. +""" + +from datetime import date + +from django.core.exceptions import FieldDoesNotExist +from django.test import TestCase + +from django_mongodb.queryset import RawQuerySet + +from .models import ( + Author, + Book, + BookFkAsPk, + Coffee, + FriendlyAuthor, + MixedCaseIDColumn, + Reviewer, +) + + +class RawAggregateTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.a1 = Author.objects.create(first_name="Joe", last_name="Smith", dob=date(1950, 9, 20)) + cls.a2 = Author.objects.create(first_name="Jill", last_name="Doe", dob=date(1920, 4, 2)) + cls.a3 = Author.objects.create(first_name="Bob", last_name="Smith", dob=date(1986, 1, 25)) + cls.a4 = Author.objects.create(first_name="Bill", last_name="Jones", dob=date(1932, 5, 10)) + cls.b1 = Book.objects.create( + title="The awesome book", + author=cls.a1, + paperback=False, + opening_line=( + "It was a bright cold day in April and the clocks were striking " "thirteen." + ), + ) + cls.b2 = Book.objects.create( + title="The horrible book", + author=cls.a1, + paperback=True, + opening_line=( + "On an evening in the latter part of May a middle-aged man " + "was walking homeward from Shaston to the village of Marlott, " + "in the adjoining Vale of Blakemore, or Blackmoor." 
+ ), + ) + cls.b3 = Book.objects.create( + title="Another awesome book", + author=cls.a1, + paperback=False, + opening_line="A squat gray building of only thirty-four stories.", + ) + cls.b4 = Book.objects.create( + title="Some other book", + author=cls.a3, + paperback=True, + opening_line="It was the day my grandmother exploded.", + ) + cls.c1 = Coffee.objects.create(brand="dunkin doughnuts") + cls.c2 = Coffee.objects.create(brand="starbucks") + cls.r1 = Reviewer.objects.create() + cls.r2 = Reviewer.objects.create() + cls.r1.reviewed.add(cls.b2, cls.b3, cls.b4) + + def assertSuccessfulRawQuery( + self, + model, + query, + expected_results, + expected_annotations=(), + ): + """ + Execute the passed query against the passed model and check the output. + """ + results = list(model.objects.raw_aggregate(query)) + expected_results = list(expected_results) + with self.assertNumQueries(0): + self.assertProcessed(model, results, expected_results, expected_annotations) + self.assertAnnotations(results, expected_annotations) + + def assertProcessed(self, model, results, orig, expected_annotations=()): + """Compare the results of a raw query against expected results.""" + self.assertEqual(len(results), len(orig)) + for index, item in enumerate(results): + orig_item = orig[index] + for annotation in expected_annotations: + setattr(orig_item, *annotation) + + for field in model._meta.fields: + # All values on the model are equal. + self.assertEqual(getattr(item, field.attname), getattr(orig_item, field.attname)) + # This includes checking that they are the same type. 
+ self.assertEqual( + type(getattr(item, field.attname)), + type(getattr(orig_item, field.attname)), + ) + + def assertNoAnnotations(self, results): + """The results of a raw query contain no annotations.""" + self.assertAnnotations(results, ()) + + def assertAnnotations(self, results, expected_annotations): + """The passed raw query results contain the expected annotations.""" + if expected_annotations: + for index, result in enumerate(results): + annotation, value = expected_annotations[index] + self.assertTrue(hasattr(result, annotation)) + self.assertEqual(getattr(result, annotation), value) + + def test_rawqueryset_repr(self): + queryset = RawQuerySet(pipeline=[]) + self.assertEqual(repr(queryset), "<RawQuerySet: []>") + self.assertEqual(repr(queryset.query), "<RawQuery: []>") + + def test_simple_raw_query(self): + """Basic test of raw query with a simple database query.""" + query = [] + authors = Author.objects.all() + self.assertSuccessfulRawQuery(Author, query, authors) + + def test_raw_query_lazy(self): + """ + Raw queries are lazy: they aren't actually executed until they're + iterated over. + """ + q = Author.objects.raw_aggregate([]) + self.assertIsNone(q.query.cursor) + list(q) + self.assertIsNotNone(q.query.cursor) + + def test_FK_raw_query(self): + """ + Test of a simple raw query against a model containing a foreign key. + """ + query = [] + books = Book.objects.all() + self.assertSuccessfulRawQuery(Book, query, books) + + def test_db_column_handler(self): + """ + Test of a simple raw query against a model containing a field with + db_column defined. + """ + query = [] + coffees = Coffee.objects.all() + self.assertSuccessfulRawQuery(Coffee, query, coffees) + + def test_pk_with_mixed_case_db_column(self): + """ + A raw query with a model that has a pk db_column with mixed case. 
+ """ + query = [] + queryset = MixedCaseIDColumn.objects.all() + self.assertSuccessfulRawQuery(MixedCaseIDColumn, query, queryset) + + def test_order_handler(self): + """ + Test of raw query's tolerance for columns being returned in any + order. + """ + selects = ( + ("dob, last_name, first_name, id"), + ("last_name, dob, first_name, id"), + ("first_name, last_name, dob, id"), + ) + for select in selects: + select = {col: 1 for col in select.split(", ")} + query = [{"$project": select}] + authors = Author.objects.all() + self.assertSuccessfulRawQuery(Author, query, authors) + + def test_query_representation(self): + """Test representation of raw query.""" + query = [{"$match": {"last_name": "foo"}}] + qset = Author.objects.raw_aggregate(query) + self.assertEqual( + repr(qset), + "<RawQuerySet: [{'$match': {'last_name': 'foo'}}]>", + ) + self.assertEqual( + repr(qset.query), + "<RawQuery: [{'$match': {'last_name': 'foo'}}]>", + ) + + def test_many_to_many(self): + """ + Test of a simple raw query against a model containing a m2m field. + """ + query = [] + reviewers = Reviewer.objects.all() + self.assertSuccessfulRawQuery(Reviewer, query, reviewers) + + def test_missing_fields(self): + query = [{"$project": {"id": 1, "first_name": 1, "dob": 1}}] + for author in Author.objects.raw_aggregate(query): + self.assertIsNotNone(author.first_name) + # last_name isn't given, but it will be retrieved on demand. 
+ self.assertIsNotNone(author.last_name) + + def test_missing_fields_without_PK(self): + query = [{"$project": {"first_name": 1, "dob": 1, "_id": 0}}] + msg = "Raw query must include the primary key" + with self.assertRaisesMessage(FieldDoesNotExist, msg): + list(Author.objects.raw_aggregate(query)) + + def test_annotations(self): + query = [ + { + "$project": { + "first_name": 1, + "last_name": 1, + "dob": 1, + "birth_year": {"$year": "$dob"}, + }, + }, + {"$sort": {"_id": 1}}, + ] + expected_annotations = ( + ("birth_year", 1950), + ("birth_year", 1920), + ("birth_year", 1986), + ("birth_year", 1932), + ) + authors = Author.objects.order_by("pk") + self.assertSuccessfulRawQuery(Author, query, authors, expected_annotations) + + def test_multiple_iterations(self): + query = [] + normal_authors = Author.objects.all() + raw_authors = Author.objects.raw_aggregate(query) + + # First Iteration + first_iterations = 0 + for index, raw_author in enumerate(raw_authors): + self.assertEqual(normal_authors[index], raw_author) + first_iterations += 1 + + # Second Iteration + second_iterations = 0 + for index, raw_author in enumerate(raw_authors): + self.assertEqual(normal_authors[index], raw_author) + second_iterations += 1 + + self.assertEqual(first_iterations, second_iterations) + + def test_get_item(self): + # Indexing on RawQuerySets + query = [{"$sort": {"id": 1}}] + third_author = Author.objects.raw_aggregate(query)[2] + self.assertEqual(third_author.first_name, "Bob") + + first_two = Author.objects.raw_aggregate(query)[0:2] + self.assertEqual(len(first_two), 2) + + with self.assertRaises(TypeError): + Author.objects.raw_aggregate(query)["test"] + + def test_inheritance(self): + f = FriendlyAuthor.objects.create( + first_name="Wesley", last_name="Chun", dob=date(1962, 10, 28) + ) + query = [] + self.assertEqual([o.pk for o in FriendlyAuthor.objects.raw_aggregate(query)], [f.pk]) + + def test_query_count(self): + self.assertNumQueries(1, list, 
Author.objects.raw_aggregate([])) + + def test_subquery_in_raw_sql(self): + list( + Book.objects.raw_aggregate( + [{"$match": {"paperback": {"$ne": None}}}, {"$project": {"id": 1}}] + ) + ) + + def test_db_column_name_is_used_in_raw_query(self): + """ + Regression test that ensures the `column` attribute on the field is + used to generate the list of fields included in the query, as opposed + to the `attname`. This is important when the primary key is a + ForeignKey field because `attname` and `column` are not necessarily the + same. + """ + b = BookFkAsPk.objects.create(book=self.b1) + self.assertEqual( + list( + BookFkAsPk.objects.raw_aggregate([{"$project": {"not_the_default": 1, "_id": 0}}]) + ), + [b], + ) + + def test_result_caching(self): + with self.assertNumQueries(1): + books = Book.objects.raw_aggregate([]) + list(books) + list(books) + + def test_iterator(self): + with self.assertNumQueries(2): + books = Book.objects.raw_aggregate([]) + list(books.iterator()) + list(books.iterator()) + + def test_bool(self): + self.assertIs(bool(Book.objects.raw_aggregate([])), True) + self.assertIs(bool(Book.objects.raw_aggregate([{"$match": {"id": 0}}])), False) + + def test_len(self): + self.assertEqual(len(Book.objects.raw_aggregate([])), 4) + self.assertEqual(len(Book.objects.raw_aggregate([{"$match": {"id": 0}}])), 0)