diff --git a/django_mongodb_backend/expressions/search.py b/django_mongodb_backend/expressions/search.py index 19fbe137..0cbfe788 100644 --- a/django_mongodb_backend/expressions/search.py +++ b/django_mongodb_backend/expressions/search.py @@ -6,12 +6,6 @@ from ..query_utils import process_lhs, process_rhs -def cast_as_value(value): - if value is None: - return None - return Value(value) if not hasattr(value, "resolve_expression") else value - - def cast_as_field(path): return F(path) if isinstance(path, str) else path @@ -96,6 +90,12 @@ def __repr__(self): def as_sql(self, compiler, connection): return "", [] + def get_source_expressions(self): + return [self.path] + + def set_source_expressions(self, exprs): + (self.path,) = exprs + def _get_indexed_fields(self, mappings): if isinstance(mappings, list): for definition in mappings: @@ -147,32 +147,26 @@ class SearchAutocomplete(SearchExpression): def __init__(self, path, query, *, fuzzy=None, token_order=None, score=None): self.path = cast_as_field(path) - self.query = cast_as_value(query) - self.fuzzy = cast_as_value(fuzzy) - self.token_order = cast_as_value(token_order) + self.query = query + self.fuzzy = fuzzy + self.token_order = token_order self.score = score super().__init__() - def get_source_expressions(self): - return [self.path, self.query, self.fuzzy, self.token_order] - - def set_source_expressions(self, exprs): - self.path, self.query, self.fuzzy, self.token_order = exprs - def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "query": self.query.value, + "query": self.query, } if self.score: params["score"] = self.score.as_mql(compiler, connection) if self.fuzzy is not None: - params["fuzzy"] = self.fuzzy.value + params["fuzzy"] = self.fuzzy if self.token_order: - params["tokenOrder"] = self.token_order.value + params["tokenOrder"] = self.token_order return {"autocomplete": params} @@ -193,23 +187,17 @@ class SearchEquals(SearchExpression): def __init__(self, path, value, *, score=None): self.path = cast_as_field(path) - self.value = cast_as_value(value) + self.value = value self.score = score super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} - def get_source_expressions(self): - return [self.path, self.value] - - def set_source_expressions(self, exprs): - self.path, self.value = exprs - def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "value": self.value.value, + "value": self.value, } if self.score: params["score"] = self.score.as_mql(compiler, connection) @@ -242,12 +230,6 @@ def __init__(self, path, *, score=None): def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} - def get_source_expressions(self): - return [self.path] - - def set_source_expressions(self, exprs): - (self.path,) = exprs - def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), @@ -274,23 +256,17 @@ class SearchIn(SearchExpression): def __init__(self, path, value, *, score=None): self.path = cast_as_field(path) - self.value = cast_as_value(value) + self.value = value self.score = score super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} - def get_source_expressions(self): - return [self.path, self.value] - - def set_source_expressions(self, exprs): - self.path, self.value = exprs - def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "value": self.value.value, + "value": self.value, } if self.score: params["score"] = self.score.as_mql(compiler, connection) @@ -319,32 +295,26 @@ class SearchPhrase(SearchExpression): def __init__(self, path, query, *, slop=None, synonyms=None, score=None): self.path = cast_as_field(path) - self.query = cast_as_value(query) - self.slop = cast_as_value(slop) - self.synonyms = cast_as_value(synonyms) + self.query = query + self.slop = slop + self.synonyms = synonyms self.score = score super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} - def get_source_expressions(self): - return [self.path, self.query, self.slop, self.synonyms] - - def set_source_expressions(self, exprs): - self.path, self.query, self.slop, self.synonyms = exprs - def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "query": self.query.value, + "query": self.query, } if self.score: params["score"] = self.score.as_mql(compiler, connection) if self.slop: - params["slop"] = self.slop.value + params["slop"] = self.slop if self.synonyms: - params["synonyms"] = self.synonyms.value + params["synonyms"] = self.synonyms return {"phrase": params} @@ -369,23 +339,17 @@ class SearchQueryString(SearchExpression): def __init__(self, path, query, *, score=None): self.path = cast_as_field(path) - self.query = cast_as_value(query) + self.query = query self.score = score super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} - def get_source_expressions(self): - return [self.path, self.query] - - def set_source_expressions(self, exprs): - self.path, self.query = exprs - def search_operator(self, compiler, connection): params = { "defaultPath": self.path.as_mql(compiler, connection, as_path=True), - "query": self.query.value, + "query": self.query, } if self.score: params["score"] = self.score.as_mql(compiler, connection) @@ -415,22 +379,16 @@ class SearchRange(SearchExpression): def __init__(self, path, *, lt=None, lte=None, gt=None, gte=None, score=None): self.path = cast_as_field(path) - self.lt = cast_as_value(lt) - self.lte = cast_as_value(lte) - self.gt = cast_as_value(gt) - self.gte = cast_as_value(gte) + self.lt = lt + self.lte = lte + self.gt = gt + self.gte = gte self.score = score super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} - def get_source_expressions(self): - return [self.path, self.lt, self.lte, self.gt, self.gte] - - def set_source_expressions(self, exprs): - self.path, self.lt, self.lte, self.gt, self.gte = exprs - def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), @@ -438,13 +396,13 @@ def search_operator(self, compiler, connection): if self.score: params["score"] = self.score.as_mql(compiler, connection) if self.lt: - params["lt"] = self.lt.value + params["lt"] = self.lt if self.lte: - params["lte"] = self.lte.value + params["lte"] = self.lte if self.gt: - params["gt"] = self.gt.value + params["gt"] = self.gt if self.gte: - params["gte"] = self.gte.value + params["gte"] = self.gte return {"range": params} @@ -467,29 +425,23 @@ class SearchRegex(SearchExpression): def __init__(self, path, query, *, allow_analyzed_field=None, score=None): self.path = cast_as_field(path) - self.query = cast_as_value(query) - self.allow_analyzed_field = cast_as_value(allow_analyzed_field) + self.query = query + self.allow_analyzed_field = allow_analyzed_field self.score = score super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} - def get_source_expressions(self): - return [self.path, self.query, self.allow_analyzed_field] - - def set_source_expressions(self, exprs): - self.path, self.query, self.allow_analyzed_field = exprs - def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "query": self.query.value, + "query": self.query, } if self.score: params["score"] = self.score.as_mql(compiler, connection) if self.allow_analyzed_field is not None: - params["allowAnalyzedField"] = self.allow_analyzed_field.value + params["allowAnalyzedField"] = self.allow_analyzed_field return {"regex": params} @@ -521,35 +473,29 @@ class SearchText(SearchExpression): def __init__(self, path, query, *, fuzzy=None, match_criteria=None, synonyms=None, score=None): self.path = cast_as_field(path) - self.query = cast_as_value(query) - self.fuzzy = cast_as_value(fuzzy) - self.match_criteria = cast_as_value(match_criteria) - self.synonyms = cast_as_value(synonyms) + self.query = query + self.fuzzy = fuzzy + self.match_criteria = match_criteria + self.synonyms = synonyms self.score = score super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} - def get_source_expressions(self): - return [self.path, self.query, self.fuzzy, self.match_criteria, self.synonyms] - - def set_source_expressions(self, exprs): - self.path, self.query, self.fuzzy, self.match_criteria, self.synonyms = exprs - def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "query": self.query.value, + "query": self.query, } if self.score: params["score"] = self.score.as_mql(compiler, connection) if self.fuzzy is not None: - params["fuzzy"] = self.fuzzy.value + params["fuzzy"] = self.fuzzy if self.match_criteria: - params["matchCriteria"] = self.match_criteria.value + params["matchCriteria"] = self.match_criteria if self.synonyms: - params["synonyms"] = self.synonyms.value + params["synonyms"] = self.synonyms return {"text": params} @@ -575,29 +521,23 @@ class SearchWildcard(SearchExpression): def __init__(self, path, query, allow_analyzed_field=None, score=None): self.path = cast_as_field(path) - self.query = cast_as_value(query) - self.allow_analyzed_field = cast_as_value(allow_analyzed_field) + self.query = query + self.allow_analyzed_field = allow_analyzed_field self.score = score super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} - def get_source_expressions(self): - return [self.path, self.query, self.allow_analyzed_field] - - def set_source_expressions(self, exprs): - self.path, self.query, self.allow_analyzed_field = exprs - def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "query": self.query.value, + "query": self.query, } if self.score: params["score"] = self.score.as_mql(compiler, connection) if self.allow_analyzed_field is not None: - params["allowAnalyzedField"] = self.allow_analyzed_field.value + params["allowAnalyzedField"] = self.allow_analyzed_field return {"wildcard": params} @@ -627,25 +567,19 @@ class SearchGeoShape(SearchExpression): def __init__(self, path, relation, geometry, *, score=None): self.path = cast_as_field(path) - self.relation = cast_as_value(relation) - self.geometry = cast_as_value(geometry) + self.relation = relation + self.geometry = geometry self.score = score super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} - def get_source_expressions(self): - return [self.path, self.relation, self.geometry] - - def set_source_expressions(self, exprs): - self.path, self.relation, self.geometry = exprs - def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "relation": self.relation.value, - "geometry": self.geometry.value, + "relation": self.relation, + "geometry": self.geometry, } if self.score: params["score"] = self.score.as_mql(compiler, connection) @@ -677,24 +611,18 @@ class SearchGeoWithin(SearchExpression): def __init__(self, path, kind, geometry, *, score=None): self.path = cast_as_field(path) - self.kind = cast_as_value(kind) - self.geometry = cast_as_value(geometry) + self.kind = kind + self.geometry = geometry self.score = score super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} - def get_source_expressions(self): - return [self.path, self.kind, self.geometry] - - def set_source_expressions(self, exprs): - self.path, self.kind, self.geometry = exprs - def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - self.kind.value: self.geometry.value, + self.kind: self.geometry, } if self.score: params["score"] = self.score.as_mql(compiler, connection) @@ -720,15 +648,15 @@ class SearchMoreLikeThis(SearchExpression): """ def __init__(self, documents, *, score=None): - self.documents = cast_as_value(documents) + self.documents = documents self.score = score super().__init__() def get_source_expressions(self): - return [self.documents] + return [] def set_source_expressions(self, exprs): - (self.documents,) = exprs + pass def search_operator(self, compiler, connection): params = { @@ -798,6 +726,12 @@ def get_search_fields(self, compiler, connection): fields.update(clause.get_search_fields(compiler, connection)) return fields + def get_source_expressions(self): + return [] + + def set_source_expressions(self, exprs): + pass + def resolve_expression( self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False ): @@ -922,11 +856,11 @@ def __init__( filter=None, ): self.path = cast_as_field(path) - self.query_vector = cast_as_value(query_vector) - self.limit = cast_as_value(limit) - self.num_candidates = cast_as_value(num_candidates) - self.exact = cast_as_value(exact) - self.filter = cast_as_value(filter) + self.query_vector = query_vector + self.limit = limit + self.num_candidates = num_candidates + self.exact = exact + self.filter = filter super().__init__() def __invert__(self): @@ -947,26 +881,6 @@ def __ror__(self, other): def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} - def get_source_expressions(self): - return [ - self.path, - self.query_vector, - self.limit, - self.num_candidates, - self.exact, - self.filter, - ] - - def set_source_expressions(self, exprs): - ( - self.path, - self.query_vector, - self.limit, - self.num_candidates, - self.exact, - self.filter, - ) = exprs - def _get_query_index(self, fields, compiler): for search_indexes in compiler.collection.list_search_indexes(): if search_indexes["type"] == "vectorSearch": @@ -981,15 +895,15 @@ def as_mql(self, compiler, connection): params = { "index": self._get_query_index(self.get_search_fields(compiler, connection), compiler), "path": self.path.as_mql(compiler, connection, as_path=True), - "queryVector": self.query_vector.value, - "limit": self.limit.value, + "queryVector": self.query_vector, + "limit": self.limit, } if self.num_candidates: - params["numCandidates"] = self.num_candidates.value + params["numCandidates"] = self.num_candidates if self.exact: - params["exact"] = self.exact.value + params["exact"] = self.exact if self.filter: - params["filter"] = self.filter.value + params["filter"] = self.filter return {"$vectorSearch": params}