Skip to content

Commit 165bda8

Browse files
committed
Add combined expressions test.
1 parent 3e1f584 commit 165bda8

File tree

1 file changed

+40
-41
lines changed

1 file changed

+40
-41
lines changed

django_mongodb_backend/expressions/builtins.py

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,12 @@ def negate(self):
236236
def __hash__(self):
237237
return hash(self.operator)
238238

239+
def __str__(self):
240+
return self.operator
241+
242+
def __repr__(self):
243+
return self.operator
244+
239245

240246
class SearchCombinable:
241247
def _combine(self, other, connector):
@@ -288,12 +294,12 @@ def _get_query_index(self, fields, compiler):
288294
return search_indexes["name"]
289295
return "default"
290296

291-
def search_operator(self, compiler, connection):
297+
def search_operator(self):
292298
raise NotImplementedError
293299

294300
def as_mql(self, compiler, connection):
295301
index = self._get_query_index(self.get_search_fields(), compiler)
296-
return {"$search": {**self.search_operator(compiler, connection), "index": index}}
302+
return {"$search": {**self.search_operator(), "index": index}}
297303

298304

299305
class SearchAutocomplete(SearchExpression):
@@ -307,7 +313,7 @@ def __init__(self, path, query, fuzzy=None, score=None):
307313
def get_search_fields(self):
308314
return {self.path}
309315

310-
def search_operator(self, compiler, connection):
316+
def search_operator(self):
311317
params = {
312318
"path": self.path,
313319
"query": self.query,
@@ -329,7 +335,7 @@ def __init__(self, path, value, score=None):
329335
def get_search_fields(self):
330336
return {self.path}
331337

332-
def search_operator(self, compiler, connection):
338+
def search_operator(self):
333339
params = {
334340
"path": self.path,
335341
"value": self.value,
@@ -348,7 +354,7 @@ def __init__(self, path, score=None):
348354
def get_search_fields(self):
349355
return {self.path}
350356

351-
def search_operator(self, compiler, connection):
357+
def search_operator(self):
352358
params = {
353359
"path": self.path,
354360
}
@@ -367,7 +373,7 @@ def __init__(self, path, value, score=None):
367373
def get_search_fields(self):
368374
return {self.path}
369375

370-
def search_operator(self, compiler, connection):
376+
def search_operator(self):
371377
params = {
372378
"path": self.path,
373379
"value": self.value,
@@ -389,7 +395,7 @@ def __init__(self, path, query, slop=None, synonyms=None, score=None):
389395
def get_search_fields(self):
390396
return {self.path}
391397

392-
def search_operator(self, compiler, connection):
398+
def search_operator(self):
393399
params = {
394400
"path": self.path,
395401
"query": self.query,
@@ -413,7 +419,7 @@ def __init__(self, path, query, score=None):
413419
def get_search_fields(self):
414420
return {self.path}
415421

416-
def search_operator(self, compiler, connection):
422+
def search_operator(self):
417423
params = {
418424
"defaultPath": self.path,
419425
"query": self.query,
@@ -436,7 +442,7 @@ def __init__(self, path, lt=None, lte=None, gt=None, gte=None, score=None):
436442
def get_search_fields(self):
437443
return {self.path}
438444

439-
def search_operator(self, compiler, connection):
445+
def search_operator(self):
440446
params = {
441447
"path": self.path,
442448
}
@@ -464,7 +470,7 @@ def __init__(self, path, query, allow_analyzed_field=None, score=None):
464470
def get_search_fields(self):
465471
return {self.path}
466472

467-
def search_operator(self, compiler, connection):
473+
def search_operator(self):
468474
params = {
469475
"path": self.path,
470476
"query": self.query,
@@ -489,7 +495,7 @@ def __init__(self, path, query, fuzzy=None, match_criteria=None, synonyms=None,
489495
def get_search_fields(self):
490496
return {self.path}
491497

492-
def search_operator(self, compiler, connection):
498+
def search_operator(self):
493499
params = {
494500
"path": self.path,
495501
"query": self.query,
@@ -516,7 +522,7 @@ def __init__(self, path, query, allow_analyzed_field=None, score=None):
516522
def get_search_fields(self):
517523
return {self.path}
518524

519-
def search_operator(self, compiler, connection):
525+
def search_operator(self):
520526
params = {
521527
"path": self.path,
522528
"query": self.query,
@@ -539,7 +545,7 @@ def __init__(self, path, relation, geometry, score=None):
539545
def get_search_fields(self):
540546
return {self.path}
541547

542-
def search_operator(self, compiler, connection):
548+
def search_operator(self):
543549
params = {
544550
"path": self.path,
545551
"relation": self.relation,
@@ -558,7 +564,7 @@ def __init__(self, path, kind, geo_object, score=None):
558564
self.score = score
559565
super().__init__()
560566

561-
def search_operator(self, compiler, connection):
567+
def search_operator(self):
562568
params = {
563569
"path": self.path,
564570
self.kind: self.geo_object,
@@ -577,7 +583,7 @@ def __init__(self, documents, score=None):
577583
self.score = score
578584
super().__init__()
579585

580-
def search_operator(self, compiler, connection):
586+
def search_operator(self):
581587
params = {
582588
"like": self.documents,
583589
}
@@ -670,29 +676,23 @@ def get_search_fields(self):
670676
fields.update(clause.get_search_fields())
671677
return fields
672678

673-
def search_operator(self, compiler, connection):
679+
def search_operator(self):
674680
params = {}
675681
if self.must:
676-
params["must"] = [clause.search_operator(compiler, connection) for clause in self.must]
682+
params["must"] = [clause.search_operator() for clause in self.must]
677683
if self.must_not:
678-
params["mustNot"] = [
679-
clause.search_operator(compiler, connection) for clause in self.must_not
680-
]
684+
params["mustNot"] = [clause.search_operator() for clause in self.must_not]
681685
if self.should:
682-
params["should"] = [
683-
clause.search_operator(compiler, connection) for clause in self.should
684-
]
686+
params["should"] = [clause.search_operator() for clause in self.should]
685687
if self.filter:
686-
params["filter"] = [
687-
clause.search_operator(compiler, connection) for clause in self.filter
688-
]
688+
params["filter"] = [clause.search_operator() for clause in self.filter]
689689
if self.minimum_should_match is not None:
690690
params["minimumShouldMatch"] = self.minimum_should_match
691691

692692
return {"compound": params}
693693

694694
def negate(self):
695-
return CompoundExpression(must=self.must_not, must_not=self.must + self.filter)
695+
return CompoundExpression(must_not=[self])
696696

697697

698698
class CombinedSearchExpression(SearchExpression):
@@ -702,7 +702,7 @@ def __init__(self, lhs, operator, rhs):
702702
self.rhs = rhs
703703

704704
@staticmethod
705-
def _flatten(node, negated=False):
705+
def resolve(node, negated=False):
706706
if node is None:
707707
return None
708708
# Leaf, resolve the compoundExpression
@@ -711,25 +711,24 @@ def _flatten(node, negated=False):
711711
# Apply De Morgan's Laws.
712712
operator = node.operator.negate() if negated else node.operator
713713
negated = negated != (node.operator == Operator.NOT)
714-
lhs_compound = node._flatten(node.lhs, negated)
715-
rhs_compound = node._flatten(node.rhs, negated)
714+
lhs_compound = node.resolve(node.lhs, negated)
715+
rhs_compound = node.resolve(node.rhs, negated)
716716
if operator == Operator.OR:
717717
return CompoundExpression(should=[lhs_compound, rhs_compound], minimum_should_match=1)
718-
if node.operator == Operator.AND:
719-
return CompoundExpression(
720-
must=lhs_compound.must + rhs_compound.must,
721-
must_not=lhs_compound.must_not + rhs_compound.must_not,
722-
should=lhs_compound.should + rhs_compound.should,
723-
filter=lhs_compound.filter + rhs_compound.filter,
724-
)
725-
# it also can be written as:
726-
# this way is more consistent with OR, but the above is shorter in the debug query.
727-
# return CompoundExpression(must=[lhs_compound, rhs_compound])
718+
if operator == Operator.AND:
719+
# NOTE: we can't just do the code below, think about this case (A | B) & (C | D)
720+
# return CompoundExpression(
721+
# must=lhs_compound.must + rhs_compound.must,
722+
# must_not=lhs_compound.must_not + rhs_compound.must_not,
723+
# should=lhs_compound.should + rhs_compound.should,
724+
# filter=lhs_compound.filter + rhs_compound.filter,
725+
# )
726+
return CompoundExpression(must=[lhs_compound, rhs_compound])
728727
# not operator
729728
return lhs_compound
730729

731730
def as_mql(self, compiler, connection):
732-
expression = self._flatten(self)
731+
expression = self.resolve(self)
733732
return expression.as_mql(compiler, connection)
734733

735734

0 commit comments

Comments
 (0)