Skip to content

Commit 3c92a63

Browse files
committed
Use as_mql_idx when creating indexes.
1 parent e1233b1 commit 3c92a63

File tree

4 files changed

+64
-10
lines changed

4 files changed

+64
-10
lines changed

django_mongodb/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,15 @@ def _isnull_operator(a, b):
118118
"iregex": lambda a, b: regex_match(a, b, insensitive=True),
119119
}
120120

121+
mongo_operators_idx = {
122+
"exact": lambda a, b: {a: {"$eq": b}},
123+
"gt": lambda a, b: {a: {"$gt": b}},
124+
"gte": lambda a, b: {a: {"$gte": b}},
125+
"lt": lambda a, b: {a: {"$lt": b}},
126+
"lte": lambda a, b: {a: {"$lte": b}},
127+
"in": lambda a, b: {a: {"$in": b}},
128+
}
129+
121130
display_name = "MongoDB"
122131
vendor = "mongodb"
123132
Database = Database

django_mongodb/indexes.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,7 @@ def _get_condition_mql(self, model, schema_editor):
77
query = Query(model=model, alias_cols=False)
88
where = query.build_where(self.condition)
99
compiler = query.get_compiler(connection=schema_editor.connection)
10-
mql_ = where.as_mql(compiler, schema_editor.connection)
11-
# Transform aggregate() query syntax into find() syntax.
12-
mql = {}
13-
for key in mql_:
14-
col, value = mql_[key]
15-
# multiple conditions don't work yet
16-
if isinstance(col, dict):
17-
return {}
18-
mql[col.lstrip("$")] = {key: value}
19-
return mql
10+
return where.as_mql_idx(compiler, schema_editor.connection)
2011

2112

2213
def register_indexes():

django_mongodb/lookups.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from django.db import NotSupportedError
2+
from django.db.models.expressions import Col
23
from django.db.models.fields.related_lookups import In, MultiColSource, RelatedIn
34
from django.db.models.lookups import (
45
BuiltinLookup,
@@ -17,6 +18,14 @@ def builtin_lookup(self, compiler, connection):
1718
return connection.mongo_operators[self.lookup_name](lhs_mql, value)
1819

1920

21+
def builtin_lookup_idx(self, compiler, connection):
22+
if not isinstance(self.lhs, Col):
23+
raise NotSupportedError("Expressions as indexes are not supported.")
24+
lhs_mql = self.lhs.target.column
25+
value = process_rhs(self, compiler, connection)
26+
return connection.mongo_operators_idx[self.lookup_name](lhs_mql, value)
27+
28+
2029
_field_resolve_expression_parameter = FieldGetDbPrepValueIterableMixin.resolve_expression_parameter
2130

2231

@@ -93,6 +102,7 @@ def uuid_text_mixin(self, compiler, connection): # noqa: ARG001
93102

94103
def register_lookups():
95104
BuiltinLookup.as_mql = builtin_lookup
105+
BuiltinLookup.as_mql_idx = builtin_lookup_idx
96106
FieldGetDbPrepValueIterableMixin.resolve_expression_parameter = (
97107
field_resolve_expression_parameter
98108
)

django_mongodb/query.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,52 @@ def where_node(self, compiler, connection):
302302
return mql
303303

304304

305+
def where_node_idx(self, compiler, connection):
306+
if self.connector == AND:
307+
full_needed, empty_needed = len(self.children), 1
308+
else:
309+
full_needed, empty_needed = 1, len(self.children)
310+
if self.connector == AND:
311+
operator = "$and"
312+
elif self.connector == XOR:
313+
raise NotSupportedError("Xor in indexes is not supported")
314+
else:
315+
operator = "$or"
316+
if self.negated:
317+
raise NotSupportedError("Negated field in indexes is not supported")
318+
children_mql = []
319+
for child in self.children:
320+
try:
321+
mql = child.as_mql_idx(compiler, connection)
322+
except EmptyResultSet:
323+
empty_needed -= 1
324+
except FullResultSet:
325+
full_needed -= 1
326+
else:
327+
if mql:
328+
children_mql.append(mql)
329+
else:
330+
full_needed -= 1
331+
332+
if empty_needed == 0:
333+
raise (FullResultSet if self.negated else EmptyResultSet)
334+
if full_needed == 0:
335+
raise (EmptyResultSet if self.negated else FullResultSet)
336+
337+
if len(children_mql) == 1:
338+
mql = children_mql[0]
339+
elif len(children_mql) > 1:
340+
mql = {operator: children_mql}
341+
else:
342+
mql = {}
343+
if not mql:
344+
raise FullResultSet
345+
return mql
346+
347+
305348
def register_nodes():
306349
ExtraWhere.as_mql = extra_where
307350
Join.as_mql = join
308351
NothingNode.as_mql = NothingNode.as_sql
309352
WhereNode.as_mql = where_node
353+
WhereNode.as_mql_idx = where_node_idx

0 commit comments

Comments
 (0)