Skip to content

Commit a3481ed

Browse files
committed
add aggregation function to keytransforms.
1 parent 3bf106f commit a3481ed

File tree

1 file changed

+31
-18
lines changed

1 file changed

+31
-18
lines changed

django_mongodb/fields/auto.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,9 @@ def from_db_value(self, value, expression, connection):
6464

6565
def json_process_rhs(node, compiler, connection):
6666
_, value = node.process_rhs(compiler, connection)
67-
6867
lookup_name = node.lookup_name
6968
if lookup_name not in ("in", "range"):
7069
value = value[0] if len(value) > 0 else []
71-
else:
72-
result_value = []
73-
for ind, elem in enumerate(node.rhs):
74-
item = f"${value[ind]}" if isinstance(elem, KeyTransform) else value[ind]
75-
result_value.append(item)
76-
value = result_value
7770

7871
return value
7972

@@ -97,35 +90,31 @@ def contained_by(self, compiler, connection): # noqa: ARG001
9790

9891

9992
def json_exact(self, compiler, connection):
100-
lhs_mql = process_lhs(self, compiler, connection)
93+
lhs_mql = process_lhs(self, compiler, connection, bare_column_ref=True)
10194
rhs_mql = json_process_rhs(self, compiler, connection)
10295
if rhs_mql == "null":
10396
return {"$or": [{lhs_mql: {"$eq": None}}, {lhs_mql: {"$exists": False}}]}
104-
# return {lhs_mql: {"$eq": None, "$exists": True}}
105-
# return key_transform_isnull(self, compiler, connection)
10697
return {lhs_mql: {"$eq": rhs_mql, "$exists": True}}
10798

10899

109100
def key_transform_isnull(self, compiler, connection):
110-
lhs_mql = process_lhs(self, compiler, connection)
101+
lhs_mql = process_lhs(self, compiler, connection, bare_column_ref=True)
111102
rhs_mql = json_process_rhs(self, compiler, connection)
112-
# if rhs_mql is False:
113-
# return {lhs_mql: {"$neq": None}}
114-
# return {"$or": [{lhs_mql: {"$eq": None}}, {lhs_mql: {"$exists": False}}]}
115103

116104
# https://code.djangoproject.com/ticket/32252
117105
return {lhs_mql: {"$exists": not rhs_mql}}
118106

119107

120108
def key_transform_in(self, compiler, connection):
121-
lhs_mql = process_lhs(self, compiler, connection)
109+
lhs_mql = key_transform_agg(self.lhs, compiler, connection)
110+
bare_lhs_mql = process_lhs(self, compiler, connection, bare_column_ref=True)
122111
value = json_process_rhs(self, compiler, connection)
123-
rhs_mql = connection.operators[self.lookup_name](value)
124-
return {"$expr": {lhs_mql: rhs_mql}}
112+
expr = connection.mongo_aggregations[self.lookup_name](lhs_mql, value)
113+
return {"$expr": expr, bare_lhs_mql: {"$exists": True}}
125114

126115

127116
def has_key_lookup(self, compiler, connection):
128-
lhs = process_lhs(self, compiler, connection)
117+
lhs = process_lhs(self, compiler, connection, bare_column_ref=True)
129118
rhs = self.rhs
130119
if not isinstance(rhs, (list | tuple)):
131120
rhs = [rhs]
@@ -148,10 +137,34 @@ def has_key_lookup(self, compiler, connection):
148137
return {self.mongo_operator: keys}
149138

150139

140+
def key_transform_agg(self, compiler, connection):
141+
key_transforms = [self.key_name]
142+
previous = self.lhs
143+
while isinstance(previous, KeyTransform):
144+
key_transforms.insert(0, previous.key_name)
145+
previous = previous.lhs
146+
lhs_mql = previous.as_mql(compiler, connection)
147+
result = f"{lhs_mql}"
148+
for key in key_transforms:
149+
get_field = {"$getField": {"input": result, "field": key}}
150+
if key.isdigit():
151+
result = {
152+
"$cond": {
153+
"if": {"$isArray": result},
154+
"then": {"$arrayElemAt": [result, int(key)]},
155+
"else": get_field,
156+
}
157+
}
158+
else:
159+
result = get_field
160+
return result
161+
162+
151163
def load_fields():
152164
JSONField.from_db_value = from_db_value
153165
DataContains.as_mql = data_contains
154166
KeyTransform.as_mql = key_transform
167+
KeyTransform.as_mql_agg = key_transform_agg
155168
JSONExact.as_mql = json_exact
156169
ContainedBy.as_mql = contained_by
157170
HasKeyLookup.as_mql = has_key_lookup

0 commit comments

Comments
 (0)