Skip to content

Commit 918237a

Browse files
committed
POC: Manage sub array queries with a different transform path.
1 parent c357052 commit 918237a

File tree

1 file changed

+51
-20
lines changed

1 file changed

+51
-20
lines changed

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -73,31 +73,53 @@ def as_mql(self, compiler, connection):
7373
if isinstance(value, models.Model):
7474
value, emf_data = self.model_to_dict(value)
7575
# Get conditions for any nested EmbeddedModelFields.
76-
conditions = self.get_conditions({"$$item": (value, emf_data)})
76+
conditions = self.get_conditions({lhs_mql[1]: (value, emf_data)})
7777
return {
7878
"$anyElementTrue": {
79-
"$map": {"input": lhs_mql, "as": "item", "in": {"$and": conditions}}
79+
"$ifNull": [
80+
{
81+
"$map": {
82+
"input": lhs_mql[0],
83+
"as": "item",
84+
"in": {"$and": conditions},
85+
}
86+
},
87+
[],
88+
]
8089
}
8190
}
82-
lhs_mql = process_lhs(self.lhs, compiler, connection)
8391
return {
8492
"$anyElementTrue": {
85-
"$map": {
86-
"input": lhs_mql,
87-
"as": "item",
88-
"in": {"$eq": [f"$$item.{self.lhs.key_name}", value]},
89-
}
93+
"$ifNull": [
94+
{
95+
"$map": {
96+
"input": lhs_mql[0],
97+
"as": "item",
98+
"in": {"$eq": [lhs_mql[1], value]},
99+
}
100+
},
101+
[],
102+
]
90103
}
91104
}
92105
return connection.mongo_operators[self.lookup_name](lhs_mql, value)
93106

94107

95108
class KeyTransform(Transform):
96109
# it should be different class than EMF keytransform even most of the methods are equal.
97-
def __init__(self, key_name, ref_field, *args, **kwargs):
110+
def __init__(self, key_name, base_field, *args, **kwargs):
98111
super().__init__(*args, **kwargs)
99-
self.key_name = str(key_name)
100-
self.ref_field = ref_field
112+
self.base_field = base_field
113+
# TODO: Need to create a column, will refactor this thing.
114+
column_target = base_field.clone()
115+
column_target.db_column = f"$item.{key_name}"
116+
column_target.set_attributes_from_name(f"$item.{key_name}")
117+
self._lhs = Col(None, column_target)
118+
self._sub_transform = None
119+
120+
def __call__(self, this, *args, **kwargs):
121+
self._lhs = self._sub_transform(self._lhs, *args, **kwargs)
122+
return self
101123

102124
def get_lookup(self, name):
103125
return self.output_field.get_lookup(name)
@@ -107,33 +129,42 @@ def get_transform(self, name):
107129
Validate that `name` is either a field of an embedded model or a
108130
lookup on an embedded model's field.
109131
"""
110-
if transform := self.ref_field.get_transform(name):
111-
return transform
112-
suggested_lookups = difflib.get_close_matches(name, self.ref_field.get_lookups())
132+
if isinstance(self._lhs, Transform):
133+
transform = self._lhs.get_transform(name)
134+
else:
135+
transform = self.base_field.get_transform(name)
136+
if transform:
137+
self._sub_transform = transform
138+
return self
139+
suggested_lookups = difflib.get_close_matches(name, self.base_field.get_lookups())
113140
if suggested_lookups:
114141
suggested_lookups = " or ".join(suggested_lookups)
115142
suggestion = f", perhaps you meant {suggested_lookups}?"
116143
else:
117144
suggestion = "."
118145
raise FieldDoesNotExist(
119146
f"Unsupported lookup '{name}' for "
120-
f"{self.ref_field.__class__.__name__} '{self.ref_field.name}'"
147+
f"{self.base_field.__class__.__name__} '{self.base_field.name}'"
121148
f"{suggestion}"
122149
)
123150

124151
def as_mql(self, compiler, connection):
152+
if isinstance(self._lhs, Transform):
153+
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
154+
else:
155+
inner_lhs_mql = None
125156
lhs_mql = process_lhs(self, compiler, connection)
126-
return f"{lhs_mql}.{self.key_name}"
157+
return lhs_mql, inner_lhs_mql
127158

128159
@property
129160
def output_field(self):
130-
return EmbeddedModelArrayField(self.ref_field)
161+
return EmbeddedModelArrayField(self.base_field)
131162

132163

133164
class KeyTransformFactory:
134-
def __init__(self, key_name, ref_field):
165+
def __init__(self, key_name, base_field):
135166
self.key_name = key_name
136-
self.ref_field = ref_field
167+
self.base_field = base_field
137168

138169
def __call__(self, *args, **kwargs):
139-
return KeyTransform(self.key_name, self.ref_field, *args, **kwargs)
170+
return KeyTransform(self.key_name, self.base_field, *args, **kwargs)

0 commit comments

Comments
 (0)