Skip to content

Commit 7e2604c

Browse files
committed
Add Path combinable.
1 parent 690dec9 commit 7e2604c

File tree

4 files changed

+58
-48
lines changed

4 files changed

+58
-48
lines changed

django_mongodb_backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def _compound_searches_queries(self, search_replacements):
302302
search.as_mql(self, self.connection),
303303
{
304304
"$addFields": {
305-
result_col.as_mql(self, self.connection, as_path=True): {
305+
result_col.as_mql(self, self.connection).removeprefix("$"): {
306306
"$meta": score_function
307307
}
308308
}

django_mongodb_backend/expressions/builtins.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from bson import Decimal128
66
from django.core.exceptions import EmptyResultSet, FullResultSet
77
from django.db import NotSupportedError
8+
from django.db.models import F
89
from django.db.models.expressions import (
910
Case,
1011
Col,
@@ -53,7 +54,7 @@ def case(self, compiler, connection):
5354
}
5455

5556

56-
def col(self, compiler, connection, as_path=False): # noqa: ARG001
57+
def col(self, compiler, connection): # noqa: ARG001
5758
# If the column is part of a subquery and belongs to one of the parent
5859
# queries, it will be stored for reference using $let in a $lookup stage.
5960
# If the query is built with `alias_cols=False`, treat the column as
@@ -71,7 +72,7 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001
7172
# Add the column's collection's alias for columns in joined collections.
7273
has_alias = self.alias and self.alias != compiler.collection_name
7374
prefix = f"{self.alias}." if has_alias else ""
74-
if not as_path:
75+
if not getattr(self, "_as_path", False):
7576
prefix = f"${prefix}"
7677
return f"{prefix}{self.target.column}"
7778

@@ -209,6 +210,13 @@ def value(self, compiler, connection): # noqa: ARG001
209210
return value
210211

211212

213+
class Path(F):
214+
def resolve_expression(self, *args, **kwargs):
215+
expr = super().resolve_expression(*args, **kwargs)
216+
expr._as_path = True
217+
return expr
218+
219+
212220
def register_expressions():
213221
Case.as_mql = case
214222
Col.as_mql = col

django_mongodb_backend/expressions/search.py

Lines changed: 44 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from django.db import NotSupportedError
22
from django.db.models import CharField, Expression, FloatField, TextField
3-
from django.db.models.expressions import F, Value
3+
from django.db.models.expressions import Value
44
from django.db.models.lookups import Lookup
55

66
from ..query_utils import process_lhs, process_rhs
7+
from .builtins import Path
78

89

9-
def cast_as_field(path):
10-
return F(path) if isinstance(path, str) else path
10+
def cast_as_path(path):
11+
return Path(path) if isinstance(path, str) else path
1112

1213

1314
class Operator:
@@ -146,19 +147,19 @@ class SearchAutocomplete(SearchExpression):
146147
"""
147148

148149
def __init__(self, path, query, *, fuzzy=None, token_order=None, score=None):
149-
self.path = cast_as_field(path)
150+
self.path = cast_as_path(path)
150151
self.query = query
151152
self.fuzzy = fuzzy
152153
self.token_order = token_order
153154
self.score = score
154155
super().__init__()
155156

156157
def get_search_fields(self, compiler, connection):
157-
return {self.path.as_mql(compiler, connection, as_path=True)}
158+
return {self.path.as_mql(compiler, connection)}
158159

159160
def search_operator(self, compiler, connection):
160161
params = {
161-
"path": self.path.as_mql(compiler, connection, as_path=True),
162+
"path": self.path.as_mql(compiler, connection),
162163
"query": self.query,
163164
}
164165
if self.score:
@@ -186,17 +187,17 @@ class SearchEquals(SearchExpression):
186187
"""
187188

188189
def __init__(self, path, value, *, score=None):
189-
self.path = cast_as_field(path)
190+
self.path = cast_as_path(path)
190191
self.value = value
191192
self.score = score
192193
super().__init__()
193194

194195
def get_search_fields(self, compiler, connection):
195-
return {self.path.as_mql(compiler, connection, as_path=True)}
196+
return {self.path.as_mql(compiler, connection)}
196197

197198
def search_operator(self, compiler, connection):
198199
params = {
199-
"path": self.path.as_mql(compiler, connection, as_path=True),
200+
"path": self.path.as_mql(compiler, connection),
200201
"value": self.value,
201202
}
202203
if self.score:
@@ -223,16 +224,16 @@ class SearchExists(SearchExpression):
223224
"""
224225

225226
def __init__(self, path, *, score=None):
226-
self.path = cast_as_field(path)
227+
self.path = cast_as_path(path)
227228
self.score = score
228229
super().__init__()
229230

230231
def get_search_fields(self, compiler, connection):
231-
return {self.path.as_mql(compiler, connection, as_path=True)}
232+
return {self.path.as_mql(compiler, connection)}
232233

233234
def search_operator(self, compiler, connection):
234235
params = {
235-
"path": self.path.as_mql(compiler, connection, as_path=True),
236+
"path": self.path.as_mql(compiler, connection),
236237
}
237238
if self.score:
238239
params["score"] = self.score.as_mql(compiler, connection)
@@ -255,17 +256,17 @@ class SearchIn(SearchExpression):
255256
"""
256257

257258
def __init__(self, path, value, *, score=None):
258-
self.path = cast_as_field(path)
259+
self.path = cast_as_path(path)
259260
self.value = value
260261
self.score = score
261262
super().__init__()
262263

263264
def get_search_fields(self, compiler, connection):
264-
return {self.path.as_mql(compiler, connection, as_path=True)}
265+
return {self.path.as_mql(compiler, connection)}
265266

266267
def search_operator(self, compiler, connection):
267268
params = {
268-
"path": self.path.as_mql(compiler, connection, as_path=True),
269+
"path": self.path.as_mql(compiler, connection),
269270
"value": self.value,
270271
}
271272
if self.score:
@@ -294,19 +295,19 @@ class SearchPhrase(SearchExpression):
294295
"""
295296

296297
def __init__(self, path, query, *, slop=None, synonyms=None, score=None):
297-
self.path = cast_as_field(path)
298+
self.path = cast_as_path(path)
298299
self.query = query
299300
self.slop = slop
300301
self.synonyms = synonyms
301302
self.score = score
302303
super().__init__()
303304

304305
def get_search_fields(self, compiler, connection):
305-
return {self.path.as_mql(compiler, connection, as_path=True)}
306+
return {self.path.as_mql(compiler, connection)}
306307

307308
def search_operator(self, compiler, connection):
308309
params = {
309-
"path": self.path.as_mql(compiler, connection, as_path=True),
310+
"path": self.path.as_mql(compiler, connection),
310311
"query": self.query,
311312
}
312313
if self.score:
@@ -338,17 +339,17 @@ class SearchQueryString(SearchExpression):
338339
"""
339340

340341
def __init__(self, path, query, *, score=None):
341-
self.path = cast_as_field(path)
342+
self.path = cast_as_path(path)
342343
self.query = query
343344
self.score = score
344345
super().__init__()
345346

346347
def get_search_fields(self, compiler, connection):
347-
return {self.path.as_mql(compiler, connection, as_path=True)}
348+
return {self.path.as_mql(compiler, connection)}
348349

349350
def search_operator(self, compiler, connection):
350351
params = {
351-
"defaultPath": self.path.as_mql(compiler, connection, as_path=True),
352+
"defaultPath": self.path.as_mql(compiler, connection),
352353
"query": self.query,
353354
}
354355
if self.score:
@@ -378,7 +379,7 @@ class SearchRange(SearchExpression):
378379
"""
379380

380381
def __init__(self, path, *, lt=None, lte=None, gt=None, gte=None, score=None):
381-
self.path = cast_as_field(path)
382+
self.path = cast_as_path(path)
382383
self.lt = lt
383384
self.lte = lte
384385
self.gt = gt
@@ -387,11 +388,11 @@ def __init__(self, path, *, lt=None, lte=None, gt=None, gte=None, score=None):
387388
super().__init__()
388389

389390
def get_search_fields(self, compiler, connection):
390-
return {self.path.as_mql(compiler, connection, as_path=True)}
391+
return {self.path.as_mql(compiler, connection)}
391392

392393
def search_operator(self, compiler, connection):
393394
params = {
394-
"path": self.path.as_mql(compiler, connection, as_path=True),
395+
"path": self.path.as_mql(compiler, connection),
395396
}
396397
if self.score:
397398
params["score"] = self.score.as_mql(compiler, connection)
@@ -424,18 +425,18 @@ class SearchRegex(SearchExpression):
424425
"""
425426

426427
def __init__(self, path, query, *, allow_analyzed_field=None, score=None):
427-
self.path = cast_as_field(path)
428+
self.path = cast_as_path(path)
428429
self.query = query
429430
self.allow_analyzed_field = allow_analyzed_field
430431
self.score = score
431432
super().__init__()
432433

433434
def get_search_fields(self, compiler, connection):
434-
return {self.path.as_mql(compiler, connection, as_path=True)}
435+
return {self.path.as_mql(compiler, connection)}
435436

436437
def search_operator(self, compiler, connection):
437438
params = {
438-
"path": self.path.as_mql(compiler, connection, as_path=True),
439+
"path": self.path.as_mql(compiler, connection),
439440
"query": self.query,
440441
}
441442
if self.score:
@@ -472,7 +473,7 @@ class SearchText(SearchExpression):
472473
"""
473474

474475
def __init__(self, path, query, *, fuzzy=None, match_criteria=None, synonyms=None, score=None):
475-
self.path = cast_as_field(path)
476+
self.path = cast_as_path(path)
476477
self.query = query
477478
self.fuzzy = fuzzy
478479
self.match_criteria = match_criteria
@@ -481,11 +482,11 @@ def __init__(self, path, query, *, fuzzy=None, match_criteria=None, synonyms=Non
481482
super().__init__()
482483

483484
def get_search_fields(self, compiler, connection):
484-
return {self.path.as_mql(compiler, connection, as_path=True)}
485+
return {self.path.as_mql(compiler, connection)}
485486

486487
def search_operator(self, compiler, connection):
487488
params = {
488-
"path": self.path.as_mql(compiler, connection, as_path=True),
489+
"path": self.path.as_mql(compiler, connection),
489490
"query": self.query,
490491
}
491492
if self.score:
@@ -520,18 +521,18 @@ class SearchWildcard(SearchExpression):
520521
"""
521522

522523
def __init__(self, path, query, allow_analyzed_field=None, score=None):
523-
self.path = cast_as_field(path)
524+
self.path = cast_as_path(path)
524525
self.query = query
525526
self.allow_analyzed_field = allow_analyzed_field
526527
self.score = score
527528
super().__init__()
528529

529530
def get_search_fields(self, compiler, connection):
530-
return {self.path.as_mql(compiler, connection, as_path=True)}
531+
return {self.path.as_mql(compiler, connection)}
531532

532533
def search_operator(self, compiler, connection):
533534
params = {
534-
"path": self.path.as_mql(compiler, connection, as_path=True),
535+
"path": self.path.as_mql(compiler, connection),
535536
"query": self.query,
536537
}
537538
if self.score:
@@ -566,18 +567,18 @@ class SearchGeoShape(SearchExpression):
566567
"""
567568

568569
def __init__(self, path, relation, geometry, *, score=None):
569-
self.path = cast_as_field(path)
570+
self.path = cast_as_path(path)
570571
self.relation = relation
571572
self.geometry = geometry
572573
self.score = score
573574
super().__init__()
574575

575576
def get_search_fields(self, compiler, connection):
576-
return {self.path.as_mql(compiler, connection, as_path=True)}
577+
return {self.path.as_mql(compiler, connection)}
577578

578579
def search_operator(self, compiler, connection):
579580
params = {
580-
"path": self.path.as_mql(compiler, connection, as_path=True),
581+
"path": self.path.as_mql(compiler, connection),
581582
"relation": self.relation,
582583
"geometry": self.geometry,
583584
}
@@ -610,18 +611,18 @@ class SearchGeoWithin(SearchExpression):
610611
"""
611612

612613
def __init__(self, path, kind, geometry, *, score=None):
613-
self.path = cast_as_field(path)
614+
self.path = cast_as_path(path)
614615
self.kind = kind
615616
self.geometry = geometry
616617
self.score = score
617618
super().__init__()
618619

619620
def get_search_fields(self, compiler, connection):
620-
return {self.path.as_mql(compiler, connection, as_path=True)}
621+
return {self.path.as_mql(compiler, connection)}
621622

622623
def search_operator(self, compiler, connection):
623624
params = {
624-
"path": self.path.as_mql(compiler, connection, as_path=True),
625+
"path": self.path.as_mql(compiler, connection),
625626
self.kind: self.geometry,
626627
}
627628
if self.score:
@@ -855,7 +856,7 @@ def __init__(
855856
exact=None,
856857
filter=None,
857858
):
858-
self.path = cast_as_field(path)
859+
self.path = cast_as_path(path)
859860
self.query_vector = query_vector
860861
self.limit = limit
861862
self.num_candidates = num_candidates
@@ -879,7 +880,7 @@ def __ror__(self, other):
879880
raise NotSupportedError("SearchVector cannot be combined")
880881

881882
def get_search_fields(self, compiler, connection):
882-
return {self.path.as_mql(compiler, connection, as_path=True)}
883+
return {self.path.as_mql(compiler, connection)}
883884

884885
def _get_query_index(self, fields, compiler):
885886
for search_indexes in compiler.collection.list_search_indexes():
@@ -894,7 +895,7 @@ def _get_query_index(self, fields, compiler):
894895
def as_mql(self, compiler, connection):
895896
params = {
896897
"index": self._get_query_index(self.get_search_fields(compiler, connection), compiler),
897-
"path": self.path.as_mql(compiler, connection, as_path=True),
898+
"path": self.path.as_mql(compiler, connection),
898899
"queryVector": self.query_vector,
899900
"limit": self.limit,
900901
}
@@ -924,6 +925,7 @@ class SearchTextLookup(Lookup):
924925

925926
def __init__(self, lhs, rhs):
926927
super().__init__(lhs, rhs)
928+
self.lhs._as_path = True
927929
self.lhs = SearchText(self.lhs, self.rhs)
928930
self.rhs = Value(0)
929931

django_mongodb_backend/fields/embedded_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,14 +184,14 @@ def get_transform(self, name):
184184
f"{suggestion}"
185185
)
186186

187-
def as_mql(self, compiler, connection, as_path=False):
187+
def as_mql(self, compiler, connection):
188188
previous = self
189189
key_transforms = []
190190
while isinstance(previous, KeyTransform):
191191
key_transforms.insert(0, previous.key_name)
192192
previous = previous.lhs
193-
if as_path:
194-
mql = previous.as_mql(compiler, connection, as_path=True)
193+
if getattr(self, "_as_path", False):
194+
mql = previous.as_mql(compiler, connection).removeprefix("$")
195195
mql_path = ".".join(key_transforms)
196196
return f"{mql}.{mql_path}"
197197
mql = previous.as_mql(compiler, connection)

0 commit comments

Comments
 (0)