Skip to content

Commit 611fcf0

Browse files
committed
handle paramteres as expressions
1 parent af1cc4a commit 611fcf0

File tree

3 files changed

+147
-101
lines changed

3 files changed

+147
-101
lines changed

django_mongodb_backend/expressions/search.py

Lines changed: 102 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from django.db import NotSupportedError
2-
from django.db.models import Expression, FloatField
2+
from django.db.models import Expression, FloatField, JSONField
3+
from django.db.models.expressions import F, Value
34

45

56
class Operator:
@@ -75,42 +76,57 @@ def __repr__(self):
7576
def as_sql(self, compiler, connection):
7677
return "", []
7778

79+
def _get_indexed_fields(self, mappings):
80+
for field, definition in mappings.get("fields", {}).items():
81+
yield field
82+
for path in self._get_indexed_fields(definition):
83+
yield f"{field}.{path}"
84+
7885
def _get_query_index(self, fields, compiler):
7986
fields = set(fields)
8087
for search_indexes in compiler.collection.list_search_indexes():
8188
mappings = search_indexes["latestDefinition"]["mappings"]
82-
if mappings["dynamic"] or fields.issubset(set(mappings["fields"])):
89+
indexed_fields = set(self._get_indexed_fields(mappings))
90+
if mappings["dynamic"] or fields.issubset(indexed_fields):
8391
return search_indexes["name"]
8492
return "default"
8593

86-
def search_operator(self):
94+
def search_operator(self, compiler, connection):
8795
raise NotImplementedError
8896

8997
def as_mql(self, compiler, connection):
90-
index = self._get_query_index(self.get_search_fields(), compiler)
91-
return {"$search": {**self.search_operator(), "index": index}}
98+
index = self._get_query_index(self.get_search_fields(compiler, connection), compiler)
99+
return {"$search": {**self.search_operator(compiler, connection), "index": index}}
92100

93101

94102
class SearchAutocomplete(SearchExpression):
95-
def __init__(self, path, query, fuzzy=None, score=None):
96-
self.path = path
97-
self.query = query
103+
def __init__(self, path, query, fuzzy=None, token_order=None, score=None):
104+
self.path = F(path) if isinstance(path, str) else path
105+
self.query = Value(query) if not hasattr(query, "resolve_expression") else query
106+
if fuzzy is not None and not hasattr(fuzzy, "resolve_expression"):
107+
fuzzy = Value(fuzzy, output_field=JSONField())
98108
self.fuzzy = fuzzy
109+
if token_order is not None and not hasattr(token_order, "resolve_expression"):
110+
token_order = Value(token_order)
111+
self.token_order = token_order
99112
self.score = score
100113
super().__init__()
101114

102-
def get_search_fields(self):
103-
return {self.path}
115+
def get_search_fields(self, compiler, connection):
116+
# Shall i implement resolve_something? I think I have to do
117+
return {self.path.as_mql(compiler, connection, as_path=True)}
104118

105-
def search_operator(self):
119+
def search_operator(self, compiler, connection):
106120
params = {
107-
"path": self.path,
108-
"query": self.query,
121+
"path": self.path.as_mql(compiler, connection, as_path=True),
122+
"query": self.query.as_mql(compiler, connection),
109123
}
110124
if self.score is not None:
111-
params["score"] = self.score
125+
params["score"] = self.score.as_mql(compiler, connection)
112126
if self.fuzzy is not None:
113-
params["fuzzy"] = self.fuzzy
127+
params["fuzzy"] = self.fuzzy.as_mql(compiler, connection)
128+
if self.token_order is not None:
129+
params["tokenOrder"] = self.token_order.as_mql(compiler, connection)
114130
return {"autocomplete": params}
115131

116132

@@ -121,16 +137,16 @@ def __init__(self, path, value, score=None):
121137
self.score = score
122138
super().__init__()
123139

124-
def get_search_fields(self):
125-
return {self.path}
140+
def get_search_fields(self, compiler, connection):
141+
return {self.path.as_mql(compiler, connection, as_path=True)}
126142

127-
def search_operator(self):
143+
def search_operator(self, compiler, connection):
128144
params = {
129-
"path": self.path,
130-
"value": self.value,
145+
"path": self.path.as_mql(compiler, connection, as_path=True),
146+
"value": self.value.as_mql(compiler, connection, as_path=True),
131147
}
132148
if self.score is not None:
133-
params["score"] = self.score
149+
params["score"] = self.score.as_mql(compiler, connection, as_path=True)
134150
return {"equals": params}
135151

136152

@@ -140,15 +156,15 @@ def __init__(self, path, score=None):
140156
self.score = score
141157
super().__init__()
142158

143-
def get_search_fields(self):
144-
return {self.path}
159+
def get_search_fields(self, compiler, connection):
160+
return {self.path.as_mql(compiler, connection, as_path=True)}
145161

146-
def search_operator(self):
162+
def search_operator(self, compiler, connection):
147163
params = {
148-
"path": self.path,
164+
"path": self.path.as_mql(compiler, connection, as_path=True),
149165
}
150166
if self.score is not None:
151-
params["score"] = self.score
167+
params["score"] = self.score.definitions
152168
return {"exists": params}
153169

154170

@@ -159,16 +175,16 @@ def __init__(self, path, value, score=None):
159175
self.score = score
160176
super().__init__()
161177

162-
def get_search_fields(self):
163-
return {self.path}
178+
def get_search_fields(self, compiler, connection):
179+
return {self.path.as_mql(compiler, connection, as_path=True)}
164180

165-
def search_operator(self):
181+
def search_operator(self, compiler, connection):
166182
params = {
167-
"path": self.path,
168-
"value": self.value,
183+
"path": self.path.as_mql(compiler, connection, as_path=True),
184+
"value": self.value.as_mql(compiler, connection, as_path=True),
169185
}
170186
if self.score is not None:
171-
params["score"] = self.score
187+
params["score"] = self.score.definitions
172188
return {"in": params}
173189

174190

@@ -181,20 +197,20 @@ def __init__(self, path, query, slop=None, synonyms=None, score=None):
181197
self.synonyms = synonyms
182198
super().__init__()
183199

184-
def get_search_fields(self):
185-
return {self.path}
200+
def get_search_fields(self, compiler, connection):
201+
return {self.path.as_mql(compiler, connection, as_path=True)}
186202

187-
def search_operator(self):
203+
def search_operator(self, compiler, connection):
188204
params = {
189-
"path": self.path,
190-
"query": self.query,
205+
"path": self.path.as_mql(compiler, connection, as_path=True),
206+
"query": self.query.as_mql(compiler, connection, as_path=True),
191207
}
192208
if self.score is not None:
193-
params["score"] = self.score
209+
params["score"] = self.score.as_mql(compiler, connection, as_path=True)
194210
if self.slop is not None:
195-
params["slop"] = self.slop
211+
params["slop"] = self.slop.as_mql(compiler, connection, as_path=True)
196212
if self.synonyms is not None:
197-
params["synonyms"] = self.synonyms
213+
params["synonyms"] = self.synonyms.as_mql(compiler, connection, as_path=True)
198214
return {"phrase": params}
199215

200216

@@ -205,16 +221,16 @@ def __init__(self, path, query, score=None):
205221
self.score = score
206222
super().__init__()
207223

208-
def get_search_fields(self):
209-
return {self.path}
224+
def get_search_fields(self, compiler, connection):
225+
return {self.path.as_mql(compiler, connection, as_path=True)}
210226

211-
def search_operator(self):
227+
def search_operator(self, compiler, connection):
212228
params = {
213229
"defaultPath": self.path,
214-
"query": self.query,
230+
"query": self.query.as_mql(compiler, connection, as_path=True),
215231
}
216232
if self.score is not None:
217-
params["score"] = self.score
233+
params["score"] = self.score.definitions
218234
return {"queryString": params}
219235

220236

@@ -228,15 +244,15 @@ def __init__(self, path, lt=None, lte=None, gt=None, gte=None, score=None):
228244
self.score = score
229245
super().__init__()
230246

231-
def get_search_fields(self):
232-
return {self.path}
247+
def get_search_fields(self, compiler, connection):
248+
return {self.path.as_mql(compiler, connection, as_path=True)}
233249

234-
def search_operator(self):
250+
def search_operator(self, compiler, connection):
235251
params = {
236-
"path": self.path,
252+
"path": self.path.as_mql(compiler, connection, as_path=True),
237253
}
238254
if self.score is not None:
239-
params["score"] = self.score
255+
params["score"] = self.score.definitions
240256
if self.lt is not None:
241257
params["lt"] = self.lt
242258
if self.lte is not None:
@@ -256,16 +272,16 @@ def __init__(self, path, query, allow_analyzed_field=None, score=None):
256272
self.score = score
257273
super().__init__()
258274

259-
def get_search_fields(self):
260-
return {self.path}
275+
def get_search_fields(self, compiler, connection):
276+
return {self.path.as_mql(compiler, connection, as_path=True)}
261277

262-
def search_operator(self):
278+
def search_operator(self, compiler, connection):
263279
params = {
264-
"path": self.path,
265-
"query": self.query,
280+
"path": self.path.as_mql(compiler, connection, as_path=True),
281+
"query": self.query.as_mql(compiler, connection, as_path=True),
266282
}
267283
if self.score:
268-
params["score"] = self.score
284+
params["score"] = self.score.definitions
269285
if self.allow_analyzed_field is not None:
270286
params["allowAnalyzedField"] = self.allow_analyzed_field
271287
return {"regex": params}
@@ -281,16 +297,16 @@ def __init__(self, path, query, fuzzy=None, match_criteria=None, synonyms=None,
281297
self.score = score
282298
super().__init__()
283299

284-
def get_search_fields(self):
285-
return {self.path}
300+
def get_search_fields(self, compiler, connection):
301+
return {self.path.as_mql(compiler, connection, as_path=True)}
286302

287-
def search_operator(self):
303+
def search_operator(self, compiler, connection):
288304
params = {
289-
"path": self.path,
290-
"query": self.query,
305+
"path": self.path.as_mql(compiler, connection, as_path=True),
306+
"query": self.query.as_mql(compiler, connection, as_path=True),
291307
}
292308
if self.score:
293-
params["score"] = self.score
309+
params["score"] = self.score.definitions
294310
if self.fuzzy is not None:
295311
params["fuzzy"] = self.fuzzy
296312
if self.match_criteria is not None:
@@ -308,16 +324,16 @@ def __init__(self, path, query, allow_analyzed_field=None, score=None):
308324
self.score = score
309325
super().__init__()
310326

311-
def get_search_fields(self):
312-
return {self.path}
327+
def get_search_fields(self, compiler, connection):
328+
return {self.path.as_mql(compiler, connection, as_path=True)}
313329

314-
def search_operator(self):
330+
def search_operator(self, compiler, connection):
315331
params = {
316-
"path": self.path,
317-
"query": self.query,
332+
"path": self.path.as_mql(compiler, connection, as_path=True),
333+
"query": self.query.as_mql(compiler, connection, as_path=True),
318334
}
319335
if self.score:
320-
params["score"] = self.score
336+
params["score"] = self.score.definitions
321337
if self.allow_analyzed_field is not None:
322338
params["allowAnalyzedField"] = self.allow_analyzed_field
323339
return {"wildcard": params}
@@ -331,17 +347,17 @@ def __init__(self, path, relation, geometry, score=None):
331347
self.score = score
332348
super().__init__()
333349

334-
def get_search_fields(self):
335-
return {self.path}
350+
def get_search_fields(self, compiler, connection):
351+
return {self.path.as_mql(compiler, connection, as_path=True)}
336352

337-
def search_operator(self):
353+
def search_operator(self, compiler, connection):
338354
params = {
339-
"path": self.path,
355+
"path": self.path.as_mql(compiler, connection, as_path=True),
340356
"relation": self.relation,
341357
"geometry": self.geometry,
342358
}
343359
if self.score:
344-
params["score"] = self.score
360+
params["score"] = self.score.definitions
345361
return {"geoShape": params}
346362

347363

@@ -353,17 +369,17 @@ def __init__(self, path, kind, geo_object, score=None):
353369
self.score = score
354370
super().__init__()
355371

356-
def search_operator(self):
372+
def search_operator(self, compiler, connection):
357373
params = {
358-
"path": self.path,
374+
"path": self.path.as_mql(compiler, connection, as_path=True),
359375
self.kind: self.geo_object,
360376
}
361377
if self.score:
362-
params["score"] = self.score
378+
params["score"] = self.score.definitions
363379
return {"geoWithin": params}
364380

365-
def get_search_fields(self):
366-
return {self.path}
381+
def get_search_fields(self, compiler, connection):
382+
return {self.path.as_mql(compiler, connection, as_path=True)}
367383

368384

369385
class SearchMoreLikeThis(SearchExpression):
@@ -372,15 +388,15 @@ def __init__(self, documents, score=None):
372388
self.score = score
373389
super().__init__()
374390

375-
def search_operator(self):
391+
def search_operator(self, compiler, connection):
376392
params = {
377393
"like": self.documents,
378394
}
379395
if self.score:
380-
params["score"] = self.score
396+
params["score"] = self.score.definitions
381397
return {"moreLikeThis": params}
382398

383-
def get_search_fields(self):
399+
def get_search_fields(self, compiler, connection):
384400
needed_fields = set()
385401
for doc in self.documents:
386402
needed_fields.update(set(doc.keys()))
@@ -404,13 +420,13 @@ def __init__(
404420
self.score = score
405421
self.minimum_should_match = minimum_should_match
406422

407-
def get_search_fields(self):
423+
def get_search_fields(self, compiler, connection):
408424
fields = set()
409425
for clause in self.must + self.should + self.filter + self.must_not:
410426
fields.update(clause.get_search_fields())
411427
return fields
412428

413-
def search_operator(self):
429+
def search_operator(self, compiler, connection):
414430
params = {}
415431
if self.must:
416432
params["must"] = [clause.search_operator() for clause in self.must]
@@ -491,8 +507,8 @@ def __or__(self, other):
491507
def __ror__(self, other):
492508
raise NotSupportedError("SearchVector cannot be combined")
493509

494-
def get_search_fields(self):
495-
return {self.path}
510+
def get_search_fields(self, compiler, connection):
511+
return {self.path.as_mql(compiler, connection, as_path=True)}
496512

497513
def _get_query_index(self, fields, compiler):
498514
for search_indexes in compiler.collection.list_search_indexes():
@@ -507,7 +523,7 @@ def _get_query_index(self, fields, compiler):
507523
def as_mql(self, compiler, connection):
508524
params = {
509525
"index": self._get_query_index(self.get_search_fields(), compiler),
510-
"path": self.path,
526+
"path": self.path.as_mql(compiler, connection, as_path=True),
511527
"queryVector": self.query_vector,
512528
"limit": self.limit,
513529
}

0 commit comments

Comments
 (0)