diff --git a/app.py b/app.py index 06428ec..8d4b4d9 100644 --- a/app.py +++ b/app.py @@ -80,11 +80,13 @@ def most_similar_by_embedding( include_scores: bool = False, ) -> responses.JSONResponse: """Find the n most similar records with respect to the specified embedding. - Args: + Args: embedding_id (str): Embedding id. record_id (str): The record for which similar records are searched. limit (int): Specifies the maximum amount of returned records. - att_filter(Optional[Dict[str, Any]]]): Specifies the attribute filter for the search as dict objects. + att_filter (Optional[Dict[str, Any]]]): Specifies the attribute filter for the search as dict objects. + Note: Record values can now also be lists (e.g., {"name": ["John", "Alex"]}). + Filters will match if any or all (type "any" or type "all") of the listed values satisfy the condition. threshold: Optional[float]: None = calculated DB threshold, -9999 = no threshold, specified = use value example_filter = [ {"key": "name", "value": ["John", "Doe"]}, -> name IN ("John", "Doe") diff --git a/neural_search/util.py b/neural_search/util.py index 3563c34..422ddf9 100644 --- a/neural_search/util.py +++ b/neural_search/util.py @@ -151,7 +151,10 @@ def __is_label_filter(key: str) -> bool: def __build_filter(att_filter: List[Dict[str, Any]]) -> Optional[models.Filter]: if not att_filter: return None - must = [__build_filter_item(item) for item in att_filter] + must = [] + for item in att_filter: + cond = __build_filter_item(item) + must.append(cond) return models.Filter(must=must) @@ -178,26 +181,32 @@ def __add_access_management_filter( ) -def __build_filter_item(filter_item: Dict[str, Any]) -> models.FieldCondition: +def __build_filter_item( + filter_item: Dict[str, Any], +) -> models.FieldCondition | models.Filter: key = filter_item["key"] value = filter_item["value"] - typ = filter_item.get("type") + type = filter_item.get("type", "any") - # BETWEEN - if isinstance(value, list) and typ == "between": + if isinstance(value, list) and type == "between": return models.FieldCondition( key=key, range=models.Range(gte=value[0], lte=value[1]), ) - # IN (...) - if isinstance(value, list): + if isinstance(value, list) and type == "all": + conditions = [ + models.FieldCondition(key=key, match=models.MatchValue(value=v)) + for v in value + ] + return models.Filter(must=conditions) + + if isinstance(value, list) and type == "any": return models.FieldCondition( key=key, match=models.MatchAny(any=value), ) - # = single value return models.FieldCondition( key=key, match=models.MatchValue(value=value), diff --git a/submodules/model b/submodules/model index b41145a..d5300f1 160000 --- a/submodules/model +++ b/submodules/model @@ -1 +1 @@ -Subproject commit b41145ac4d0284b68c65b88baff034123f5403a5 +Subproject commit d5300f13c3b2cc6da16f36ffd6349795923b2c64