Skip to content

Commit 34be80f

Browse files
committed
Handle F keystransform.
1 parent 396e2bf commit 34be80f

File tree

3 files changed

+42
-13
lines changed

3 files changed

+42
-13
lines changed

django_mongodb/features.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,8 @@ class DatabaseFeatures(BaseDatabaseFeatures):
299299
"model_fields.test_jsonfield.JSONFieldTests.test_db_check_constraints",
300300
},
301301
"Mongodb's Null behaviour is different from sql's": {
302-
"model_fields.test_jsonfield.TestQuerying.test_none_key_exclude",
303-
"model_fields.test_jsonfield.TestQuerying.test_isnull_key",
302+
"model_fields.test_jsonfield.TestQuerying.test_none_key_and_exact_lookup",
303+
# "model_fields.test_jsonfield.TestQuerying.test_isnull_key",
304304
},
305305
"Pipeline filtering": {"model_fields.test_jsonfield.TestQuerying.test_icontains"},
306306
}

django_mongodb/fields.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,15 @@ def from_db_value(self, value, expression, connection):
6565
def json_process_rhs(node, compiler, connection):
6666
_, value = node.process_rhs(compiler, connection)
6767

68-
# Django's framework transform the [None] into a [null],
69-
# we have to revertit.
70-
if value == ["null"]:
71-
value = [None]
72-
7368
lookup_name = node.lookup_name
7469
if lookup_name not in ("in", "range"):
7570
value = value[0] if len(value) > 0 else []
71+
else:
72+
result_value = []
73+
for ind, elem in enumerate(node.rhs):
74+
item = f"${value[ind]}" if isinstance(elem, KeyTransform) else value[ind]
75+
result_value.append(item)
76+
value = result_value
7677

7778
return value
7879

@@ -96,24 +97,31 @@ def contained_by(self, compiler, connection): # noqa: ARG001
9697

9798

9899
def json_exact(self, compiler, connection):
99-
rhs_mql = json_process_rhs(self, compiler, connection)
100100
lhs_mql = process_lhs(self, compiler, connection)
101+
rhs_mql = json_process_rhs(self, compiler, connection)
102+
if rhs_mql == "null":
103+
return {"$or": [{lhs_mql: {"$eq": None}}, {lhs_mql: {"$exists": False}}]}
104+
# return {lhs_mql: {"$eq": None, "$exists": True}}
105+
# return key_transform_isnull(self, compiler, connection)
101106
return {lhs_mql: {"$eq": rhs_mql, "$exists": True}}
102107

103108

104109
def key_transform_isnull(self, compiler, connection):
105110
lhs_mql = process_lhs(self, compiler, connection)
106111
rhs_mql = json_process_rhs(self, compiler, connection)
107-
if rhs_mql is False:
108-
return {lhs_mql: {"$neq": None}}
109-
return {"$or": [{lhs_mql: {"$eq": None}}, {lhs_mql: {"$exists": False}}]}
112+
# if rhs_mql is False:
113+
# return {lhs_mql: {"$neq": None}}
114+
# return {"$or": [{lhs_mql: {"$eq": None}}, {lhs_mql: {"$exists": False}}]}
115+
116+
# https://code.djangoproject.com/ticket/32252
117+
return {lhs_mql: {"$exists": not rhs_mql}}
110118

111119

112120
def key_transform_in(self, compiler, connection):
113121
lhs_mql = process_lhs(self, compiler, connection)
114122
value = json_process_rhs(self, compiler, connection)
115123
rhs_mql = connection.operators[self.lookup_name](value)
116-
return {lhs_mql: rhs_mql}
124+
return {"$expr": {lhs_mql: rhs_mql}}
117125

118126

119127
def has_key_lookup(self, compiler, connection):

django_mongodb/lookups.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
from django.db import NotSupportedError
22
from django.db.models.fields.related_lookups import In, MultiColSource, RelatedIn
3-
from django.db.models.lookups import BuiltinLookup, Exact, IsNull, UUIDTextMixin
3+
from django.db.models.lookups import (
4+
BuiltinLookup,
5+
Exact,
6+
FieldGetDbPrepValueIterableMixin,
7+
IsNull,
8+
UUIDTextMixin,
9+
)
410

511
from .query_utils import process_lhs, process_rhs
612

@@ -42,10 +48,25 @@ def uuid_text_mixin(self, compiler, connection): # noqa: ARG001
4248
raise NotSupportedError("Pattern lookups on UUIDField are not supported.")
4349

4450

51+
_resolve_expression_parameter = FieldGetDbPrepValueIterableMixin.resolve_expression_parameter
52+
53+
54+
def resolve_expression_parameter(self, compiler, connection, sql, param):
55+
if connection.vendor == "mongodb":
56+
params = [param]
57+
if hasattr(param, "resolve_expression"):
58+
param = param.resolve_expression(compiler.query)
59+
if hasattr(param, "as_mql"):
60+
params = [param.as_mql(compiler, connection)]
61+
return "", params
62+
return _resolve_expression_parameter(self, compiler, connection, sql, param)
63+
64+
4565
def register_lookups():
4666
BuiltinLookup.as_mql = builtin_lookup
4767
BuiltinLookup.as_mql_agg = builtin_lookup_agg
4868
Exact.as_mql = exact
4969
In.as_mql = RelatedIn.as_mql = in_
5070
IsNull.as_mql = is_null
5171
UUIDTextMixin.as_mql = uuid_text_mixin
72+
FieldGetDbPrepValueIterableMixin.resolve_expression_parameter = resolve_expression_parameter

0 commit comments

Comments
 (0)