Skip to content

Commit 9632ad6

Browse files
committed
Use as_mql_idx when creating indexes.
1 parent 5bfd3f8 commit 9632ad6

File tree

4 files changed

+64
-10
lines changed

4 files changed

+64
-10
lines changed

django_mongodb/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,15 @@ def _isnull_operator(a, b):
119119
"iregex": lambda a, b: regex_match(a, b, insensitive=True),
120120
}
121121

122+
mongo_operators_idx = {
123+
"exact": lambda a, b: {a: {"$eq": b}},
124+
"gt": lambda a, b: {a: {"$gt": b}},
125+
"gte": lambda a, b: {a: {"$gte": b}},
126+
"lt": lambda a, b: {a: {"$lt": b}},
127+
"lte": lambda a, b: {a: {"$lte": b}},
128+
"in": lambda a, b: {a: {"$in": b}},
129+
}
130+
122131
display_name = "MongoDB"
123132
vendor = "mongodb"
124133
Database = Database

django_mongodb/indexes.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,7 @@ def _get_condition_mql(self, model, schema_editor):
77
query = Query(model=model, alias_cols=False)
88
where = query.build_where(self.condition)
99
compiler = query.get_compiler(connection=schema_editor.connection)
10-
mql_ = where.as_mql(compiler, schema_editor.connection)
11-
# Transform aggregate() query syntax into find() syntax.
12-
mql = {}
13-
for key in mql_:
14-
col, value = mql_[key]
15-
# multiple conditions don't work yet
16-
if isinstance(col, dict):
17-
return {}
18-
mql[col.lstrip("$")] = {key: value}
19-
return mql
10+
return where.as_mql_idx(compiler, schema_editor.connection)
2011

2112

2213
def register_indexes():

django_mongodb/lookups.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from django.db import NotSupportedError
2+
from django.db.models.expressions import Col
23
from django.db.models.fields.related_lookups import In, MultiColSource, RelatedIn
34
from django.db.models.lookups import (
45
BuiltinLookup,
@@ -17,6 +18,14 @@ def builtin_lookup(self, compiler, connection):
1718
return connection.mongo_operators[self.lookup_name](lhs_mql, value)
1819

1920

21+
def builtin_lookup_idx(self, compiler, connection):
22+
if not isinstance(self.lhs, Col):
23+
raise NotSupportedError("Expressions as indexes are not supported.")
24+
lhs_mql = self.lhs.target.column
25+
value = process_rhs(self, compiler, connection)
26+
return connection.mongo_operators_idx[self.lookup_name](lhs_mql, value)
27+
28+
2029
_field_resolve_expression_parameter = FieldGetDbPrepValueIterableMixin.resolve_expression_parameter
2130

2231

@@ -93,6 +102,7 @@ def uuid_text_mixin(self, compiler, connection): # noqa: ARG001
93102

94103
def register_lookups():
95104
BuiltinLookup.as_mql = builtin_lookup
105+
BuiltinLookup.as_mql_idx = builtin_lookup_idx
96106
FieldGetDbPrepValueIterableMixin.resolve_expression_parameter = (
97107
field_resolve_expression_parameter
98108
)

django_mongodb/query.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,52 @@ def where_node(self, compiler, connection):
297297
return mql
298298

299299

300+
def where_node_idx(self, compiler, connection):
301+
if self.connector == AND:
302+
full_needed, empty_needed = len(self.children), 1
303+
else:
304+
full_needed, empty_needed = 1, len(self.children)
305+
if self.connector == AND:
306+
operator = "$and"
307+
elif self.connector == XOR:
308+
raise NotSupportedError("Xor in indexes is not supported")
309+
else:
310+
operator = "$or"
311+
if self.negated:
312+
raise NotSupportedError("Negated field in indexes is not supported")
313+
children_mql = []
314+
for child in self.children:
315+
try:
316+
mql = child.as_mql_idx(compiler, connection)
317+
except EmptyResultSet:
318+
empty_needed -= 1
319+
except FullResultSet:
320+
full_needed -= 1
321+
else:
322+
if mql:
323+
children_mql.append(mql)
324+
else:
325+
full_needed -= 1
326+
327+
if empty_needed == 0:
328+
raise (FullResultSet if self.negated else EmptyResultSet)
329+
if full_needed == 0:
330+
raise (EmptyResultSet if self.negated else FullResultSet)
331+
332+
if len(children_mql) == 1:
333+
mql = children_mql[0]
334+
elif len(children_mql) > 1:
335+
mql = {operator: children_mql}
336+
else:
337+
mql = {}
338+
if not mql:
339+
raise FullResultSet
340+
return mql
341+
342+
300343
def register_nodes():
301344
ExtraWhere.as_mql = extra_where
302345
Join.as_mql = join
303346
NothingNode.as_mql = NothingNode.as_sql
304347
WhereNode.as_mql = where_node
348+
WhereNode.as_mql_idx = where_node_idx

0 commit comments

Comments
 (0)