Skip to content

Commit afa3710

Browse files
committed
Add combined expressions test.
1 parent 921008f commit afa3710

File tree

1 file changed

+40
-41
lines changed

1 file changed

+40
-41
lines changed

django_mongodb_backend/expressions/builtins.py

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,12 @@ def negate(self):
231231
def __hash__(self):
232232
return hash(self.operator)
233233

234+
def __str__(self):
235+
return self.operator
236+
237+
def __repr__(self):
238+
return self.operator
239+
234240

235241
class SearchCombinable:
236242
def _combine(self, other, connector):
@@ -283,12 +289,12 @@ def _get_query_index(self, fields, compiler):
283289
return search_indexes["name"]
284290
return "default"
285291

286-
def search_operator(self, compiler, connection):
292+
def search_operator(self):
287293
raise NotImplementedError
288294

289295
def as_mql(self, compiler, connection):
290296
index = self._get_query_index(self.get_search_fields(), compiler)
291-
return {"$search": {**self.search_operator(compiler, connection), "index": index}}
297+
return {"$search": {**self.search_operator(), "index": index}}
292298

293299

294300
class SearchAutocomplete(SearchExpression):
@@ -302,7 +308,7 @@ def __init__(self, path, query, fuzzy=None, score=None):
302308
def get_search_fields(self):
303309
return {self.path}
304310

305-
def search_operator(self, compiler, connection):
311+
def search_operator(self):
306312
params = {
307313
"path": self.path,
308314
"query": self.query,
@@ -324,7 +330,7 @@ def __init__(self, path, value, score=None):
324330
def get_search_fields(self):
325331
return {self.path}
326332

327-
def search_operator(self, compiler, connection):
333+
def search_operator(self):
328334
params = {
329335
"path": self.path,
330336
"value": self.value,
@@ -343,7 +349,7 @@ def __init__(self, path, score=None):
343349
def get_search_fields(self):
344350
return {self.path}
345351

346-
def search_operator(self, compiler, connection):
352+
def search_operator(self):
347353
params = {
348354
"path": self.path,
349355
}
@@ -362,7 +368,7 @@ def __init__(self, path, value, score=None):
362368
def get_search_fields(self):
363369
return {self.path}
364370

365-
def search_operator(self, compiler, connection):
371+
def search_operator(self):
366372
params = {
367373
"path": self.path,
368374
"value": self.value,
@@ -384,7 +390,7 @@ def __init__(self, path, query, slop=None, synonyms=None, score=None):
384390
def get_search_fields(self):
385391
return {self.path}
386392

387-
def search_operator(self, compiler, connection):
393+
def search_operator(self):
388394
params = {
389395
"path": self.path,
390396
"query": self.query,
@@ -408,7 +414,7 @@ def __init__(self, path, query, score=None):
408414
def get_search_fields(self):
409415
return {self.path}
410416

411-
def search_operator(self, compiler, connection):
417+
def search_operator(self):
412418
params = {
413419
"defaultPath": self.path,
414420
"query": self.query,
@@ -431,7 +437,7 @@ def __init__(self, path, lt=None, lte=None, gt=None, gte=None, score=None):
431437
def get_search_fields(self):
432438
return {self.path}
433439

434-
def search_operator(self, compiler, connection):
440+
def search_operator(self):
435441
params = {
436442
"path": self.path,
437443
}
@@ -459,7 +465,7 @@ def __init__(self, path, query, allow_analyzed_field=None, score=None):
459465
def get_search_fields(self):
460466
return {self.path}
461467

462-
def search_operator(self, compiler, connection):
468+
def search_operator(self):
463469
params = {
464470
"path": self.path,
465471
"query": self.query,
@@ -484,7 +490,7 @@ def __init__(self, path, query, fuzzy=None, match_criteria=None, synonyms=None,
484490
def get_search_fields(self):
485491
return {self.path}
486492

487-
def search_operator(self, compiler, connection):
493+
def search_operator(self):
488494
params = {
489495
"path": self.path,
490496
"query": self.query,
@@ -511,7 +517,7 @@ def __init__(self, path, query, allow_analyzed_field=None, score=None):
511517
def get_search_fields(self):
512518
return {self.path}
513519

514-
def search_operator(self, compiler, connection):
520+
def search_operator(self):
515521
params = {
516522
"path": self.path,
517523
"query": self.query,
@@ -534,7 +540,7 @@ def __init__(self, path, relation, geometry, score=None):
534540
def get_search_fields(self):
535541
return {self.path}
536542

537-
def search_operator(self, compiler, connection):
543+
def search_operator(self):
538544
params = {
539545
"path": self.path,
540546
"relation": self.relation,
@@ -553,7 +559,7 @@ def __init__(self, path, kind, geo_object, score=None):
553559
self.score = score
554560
super().__init__()
555561

556-
def search_operator(self, compiler, connection):
562+
def search_operator(self):
557563
params = {
558564
"path": self.path,
559565
self.kind: self.geo_object,
@@ -572,7 +578,7 @@ def __init__(self, documents, score=None):
572578
self.score = score
573579
super().__init__()
574580

575-
def search_operator(self, compiler, connection):
581+
def search_operator(self):
576582
params = {
577583
"like": self.documents,
578584
}
@@ -665,29 +671,23 @@ def get_search_fields(self):
665671
fields.update(clause.get_search_fields())
666672
return fields
667673

668-
def search_operator(self, compiler, connection):
674+
def search_operator(self):
669675
params = {}
670676
if self.must:
671-
params["must"] = [clause.search_operator(compiler, connection) for clause in self.must]
677+
params["must"] = [clause.search_operator() for clause in self.must]
672678
if self.must_not:
673-
params["mustNot"] = [
674-
clause.search_operator(compiler, connection) for clause in self.must_not
675-
]
679+
params["mustNot"] = [clause.search_operator() for clause in self.must_not]
676680
if self.should:
677-
params["should"] = [
678-
clause.search_operator(compiler, connection) for clause in self.should
679-
]
681+
params["should"] = [clause.search_operator() for clause in self.should]
680682
if self.filter:
681-
params["filter"] = [
682-
clause.search_operator(compiler, connection) for clause in self.filter
683-
]
683+
params["filter"] = [clause.search_operator() for clause in self.filter]
684684
if self.minimum_should_match is not None:
685685
params["minimumShouldMatch"] = self.minimum_should_match
686686

687687
return {"compound": params}
688688

689689
def negate(self):
690-
return CompoundExpression(must=self.must_not, must_not=self.must + self.filter)
690+
return CompoundExpression(must_not=[self])
691691

692692

693693
class CombinedSearchExpression(SearchExpression):
@@ -697,7 +697,7 @@ def __init__(self, lhs, operator, rhs):
697697
self.rhs = rhs
698698

699699
@staticmethod
700-
def _flatten(node, negated=False):
700+
def resolve(node, negated=False):
701701
if node is None:
702702
return None
703703
# Leaf, resolve the compoundExpression
@@ -706,25 +706,24 @@ def _flatten(node, negated=False):
706706
# Apply De Morgan's Laws.
707707
operator = node.operator.negate() if negated else node.operator
708708
negated = negated != (node.operator == Operator.NOT)
709-
lhs_compound = node._flatten(node.lhs, negated)
710-
rhs_compound = node._flatten(node.rhs, negated)
709+
lhs_compound = node.resolve(node.lhs, negated)
710+
rhs_compound = node.resolve(node.rhs, negated)
711711
if operator == Operator.OR:
712712
return CompoundExpression(should=[lhs_compound, rhs_compound], minimum_should_match=1)
713-
if node.operator == Operator.AND:
714-
return CompoundExpression(
715-
must=lhs_compound.must + rhs_compound.must,
716-
must_not=lhs_compound.must_not + rhs_compound.must_not,
717-
should=lhs_compound.should + rhs_compound.should,
718-
filter=lhs_compound.filter + rhs_compound.filter,
719-
)
720-
# it also can be written as:
721-
# this way is more consistent with OR, but the above is shorter in the debug query.
722-
# return CompoundExpression(must=[lhs_compound, rhs_compound])
713+
if operator == Operator.AND:
714+
# NOTE: we can't just do the code below, think about this case (A | B) & (C | D)
715+
# return CompoundExpression(
716+
# must=lhs_compound.must + rhs_compound.must,
717+
# must_not=lhs_compound.must_not + rhs_compound.must_not,
718+
# should=lhs_compound.should + rhs_compound.should,
719+
# filter=lhs_compound.filter + rhs_compound.filter,
720+
# )
721+
return CompoundExpression(must=[lhs_compound, rhs_compound])
723722
# not operator
724723
return lhs_compound
725724

726725
def as_mql(self, compiler, connection):
727-
expression = self._flatten(self)
726+
expression = self.resolve(self)
728727
return expression.as_mql(compiler, connection)
729728

730729

0 commit comments

Comments
 (0)