Skip to content

Commit dabf953

Browse files
committed
Refactor index creation
1 parent 64b1c10 commit dabf953

File tree

2 files changed

+48
-40
lines changed

2 files changed

+48
-40
lines changed

django_mongodb_backend/indexes.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from django.db.models.lookups import BuiltinLookup
44
from django.db.models.sql.query import Query
55
from django.db.models.sql.where import AND, XOR, WhereNode
6+
from pymongo import ASCENDING, DESCENDING
7+
from pymongo.operations import IndexModel
68

79
from .query_utils import process_rhs
810

@@ -58,7 +60,48 @@ def where_node_idx(self, compiler, connection):
5860
return mql
5961

6062

63+
def create_mongodb_index(self, model, schema_editor, field=None, unique=False, column_prefix=""):
64+
from collections import defaultdict
65+
66+
if self.contains_expressions:
67+
return None
68+
kwargs = {}
69+
filter_expression = defaultdict(dict)
70+
if self.condition:
71+
filter_expression.update(self._get_condition_mql(model, schema_editor))
72+
if unique:
73+
kwargs["unique"] = True
74+
# Indexing on $type matches the value of most SQL databases by
75+
# allowing multiple null values for the unique constraint.
76+
if field:
77+
column = column_prefix + field.column
78+
filter_expression[column].update({"$type": field.db_type(schema_editor.connection)})
79+
else:
80+
for field_name, _ in self.fields_orders:
81+
field_ = model._meta.get_field(field_name)
82+
filter_expression[field_.column].update(
83+
{"$type": field_.db_type(schema_editor.connection)}
84+
)
85+
if filter_expression:
86+
kwargs["partialFilterExpression"] = filter_expression
87+
index_orders = (
88+
[(column_prefix + field.column, ASCENDING)]
89+
if field
90+
else [
91+
# order is "" if ASCENDING or "DESC" if DESCENDING (see
92+
# django.db.models.indexes.Index.fields_orders).
93+
(
94+
column_prefix + model._meta.get_field(field_name).column,
95+
ASCENDING if order == "" else DESCENDING,
96+
)
97+
for field_name, order in self.fields_orders
98+
]
99+
)
100+
return IndexModel(index_orders, name=self.name, **kwargs)
101+
102+
61103
def register_indexes():
62104
BuiltinLookup.as_mql_idx = builtin_lookup_idx
63105
Index._get_condition_mql = _get_condition_mql
106+
Index.create_mongodb_index = create_mongodb_index
64107
WhereNode.as_mql_idx = where_node_idx

django_mongodb_backend/schema.py

Lines changed: 5 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
1-
from collections import defaultdict
2-
31
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
42
from django.db.models import Index, UniqueConstraint
5-
from pymongo import ASCENDING, DESCENDING
6-
from pymongo.operations import IndexModel
73

84
from .fields import EmbeddedModelField
95
from .query import wrap_database_errors
@@ -264,43 +260,12 @@ def alter_unique_together(
264260
def add_index(
265261
self, model, index, *, field=None, unique=False, column_prefix="", parent_model=None
266262
):
267-
if index.contains_expressions:
268-
return
269-
kwargs = {}
270-
filter_expression = defaultdict(dict)
271-
if index.condition:
272-
filter_expression.update(index._get_condition_mql(model, self))
273-
if unique:
274-
kwargs["unique"] = True
275-
# Indexing on $type matches the value of most SQL databases by
276-
# allowing multiple null values for the unique constraint.
277-
if field:
278-
column = column_prefix + field.column
279-
filter_expression[column].update({"$type": field.db_type(self.connection)})
280-
else:
281-
for field_name, _ in index.fields_orders:
282-
field_ = model._meta.get_field(field_name)
283-
filter_expression[field_.column].update(
284-
{"$type": field_.db_type(self.connection)}
285-
)
286-
if filter_expression:
287-
kwargs["partialFilterExpression"] = filter_expression
288-
index_orders = (
289-
[(column_prefix + field.column, ASCENDING)]
290-
if field
291-
else [
292-
# order is "" if ASCENDING or "DESC" if DESCENDING (see
293-
# django.db.models.indexes.Index.fields_orders).
294-
(
295-
column_prefix + model._meta.get_field(field_name).column,
296-
ASCENDING if order == "" else DESCENDING,
297-
)
298-
for field_name, order in index.fields_orders
299-
]
263+
idx = index.create_mongodb_index(
264+
model, self, field=field, unique=unique, column_prefix=column_prefix
300265
)
301-
idx = IndexModel(index_orders, name=index.name, **kwargs)
302-
model = parent_model or model
303-
self.get_collection(model._meta.db_table).create_indexes([idx])
266+
if idx:
267+
model = parent_model or model
268+
self.get_collection(model._meta.db_table).create_indexes([idx])
304269

305270
def _add_composed_index(self, model, field_names, column_prefix="", parent_model=None):
306271
"""Add an index on the given list of field_names."""

0 commit comments

Comments
 (0)