Skip to content

Commit 575957c

Browse files
committed
add support for partial indexes
1 parent 85d32a5 commit 575957c

File tree

6 files changed

+53
-9
lines changed

6 files changed

+53
-9
lines changed

.github/workflows/test-python.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ jobs:
8686
from_db_value
8787
generic_relations
8888
generic_relations_regress
89+
indexes
8990
introspection
9091
known_related_objects
9192
lookup

django_mongodb/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
from .expressions import register_expressions # noqa: E402
1111
from .fields import register_fields # noqa: E402
1212
from .functions import register_functions # noqa: E402
13+
from .indexes import register_indexes # noqa: E402
1314
from .lookups import register_lookups # noqa: E402
1415
from .query import register_nodes # noqa: E402
1516

1617
register_aggregates()
1718
register_expressions()
1819
register_fields()
1920
register_functions()
21+
register_indexes()
2022
register_lookups()
2123
register_nodes()

django_mongodb/expressions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ def case(self, compiler, connection):
5151

5252
def col(self, compiler, connection): # noqa: ARG001
5353
# Add the column's collection's alias for columns in joined collections.
54-
prefix = f"{self.alias}." if self.alias != compiler.collection_name else ""
54+
has_alias = self.alias and self.alias != compiler.collection_name
55+
prefix = f"{self.alias}." if has_alias else ""
5556
return f"${prefix}{self.target.column}"
5657

5758

django_mongodb/features.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
2323
# BSON Date type doesn't support microsecond precision.
2424
supports_microsecond_precision = False
2525
supports_paramstyle_pyformat = False
26-
# Not implemented.
27-
supports_partial_indexes = False
2826
supports_select_difference = False
2927
supports_select_intersection = False
3028
supports_sequence_reset = False
@@ -72,6 +70,9 @@ class DatabaseFeatures(BaseDatabaseFeatures):
7270
"backends.tests.ThreadTests.test_pass_connection_between_threads",
7371
"backends.tests.ThreadTests.test_closing_non_shared_connections",
7472
"backends.tests.ThreadTests.test_default_connection_thread_local",
73+
# TODO:
74+
"indexes.tests.PartialIndexTests.test_is_null_condition",
75+
"indexes.tests.PartialIndexTests.test_multiple_conditions",
7576
}
7677
# $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3.
7778
_django_test_expected_failures_bitwise = {

django_mongodb/indexes.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from django.db.models import Index
2+
from django.db.models.sql.query import Query
3+
4+
5+
def _get_condition_mql(self, model, schema_editor):
6+
"""Analogous to Index._get_condition_sql()."""
7+
query = Query(model=model, alias_cols=False)
8+
where = query.build_where(self.condition)
9+
compiler = query.get_compiler(connection=schema_editor.connection)
10+
mql_ = where.as_mql(compiler, schema_editor.connection)
11+
# Transform aggregate() query syntax into find() syntax.
12+
mql = {}
13+
for key in mql_:
14+
col, value = mql_[key]
15+
# multiple conditions don't work yet
16+
if isinstance(col, dict):
17+
return {}
18+
mql[col.lstrip("$")] = {key: value}
19+
return mql
20+
21+
22+
def register_indexes():
23+
Index._get_condition_mql = _get_condition_mql

django_mongodb/schema.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections import defaultdict
2+
13
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
24
from django.db.models import Index, UniqueConstraint
35
from pymongo import ASCENDING, DESCENDING
@@ -166,17 +168,23 @@ def add_index(self, model, index, field=None, unique=False):
166168
if index.contains_expressions:
167169
return
168170
kwargs = {}
171+
filter_expression = defaultdict(dict)
172+
if index.condition:
173+
filter_expression.update(index._get_condition_mql(model, self))
169174
if unique:
170-
filter_expression = {}
175+
kwargs["unique"] = True
171176
if field:
172-
filter_expression[field.column] = {"$type": field.db_type(self.connection)}
177+
filter_expression[field.column].update({"$type": field.db_type(self.connection)})
173178
else:
174179
for field_name, _ in index.fields_orders:
175180
field_ = model._meta.get_field(field_name)
176-
filter_expression[field_.column] = {"$type": field_.db_type(self.connection)}
181+
filter_expression[field_.column].update(
182+
{"$type": field_.db_type(self.connection)}
183+
)
177184
# Use partialFilterExpression to allow multiple null values for a
178185
# unique constraint.
179-
kwargs = {"partialFilterExpression": filter_expression, "unique": True}
186+
if filter_expression:
187+
kwargs["partialFilterExpression"] = filter_expression
180188
index_orders = (
181189
[(field.column, ASCENDING)]
182190
if field
@@ -260,7 +268,11 @@ def add_constraint(self, model, constraint, field=None):
260268
expressions=constraint.expressions,
261269
nulls_distinct=constraint.nulls_distinct,
262270
):
263-
idx = Index(fields=constraint.fields, name=constraint.name)
271+
idx = Index(
272+
fields=constraint.fields,
273+
condition=constraint.condition,
274+
name=constraint.name,
275+
)
264276
self.add_index(model, idx, field=field, unique=True)
265277

266278
def _add_field_unique(self, model, field):
@@ -276,7 +288,11 @@ def remove_constraint(self, model, constraint):
276288
expressions=constraint.expressions,
277289
nulls_distinct=constraint.nulls_distinct,
278290
):
279-
idx = Index(fields=constraint.fields, name=constraint.name)
291+
idx = Index(
292+
fields=constraint.fields,
293+
condition=constraint.condition,
294+
name=constraint.name,
295+
)
280296
self.remove_index(model, idx)
281297

282298
def _remove_field_unique(self, model, field):

0 commit comments

Comments
 (0)