Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -408,28 +408,58 @@ def _default_hybrid_search_query(
return payload


def _hybrid_search_query_with_post_filter(
def _hybrid_search_query_with_filter(
query_text: str,
query_vector: List[float],
k: int,
post_filter: Dict,
filters: Dict,
) -> Dict:
"""Returns payload for performing hybrid search with post filter.
"""Returns payload for performing hybrid search with filter.

Args:
query_text: The query text to search for.
query_vector: The embedding vector to search for.
k: Number of Documents to return.
post_filter: The post filter to apply.
filter: The post filter to apply.

Returns:
dict: The payload for hybrid search with post filter.
"""
search_query = _default_hybrid_search_query(query_text, query_vector, k)

search_query["post_filter"] = post_filter
payload = {
"_source": {"exclude": ["vector_field"]},
"query": {
"hybrid": {
"queries": [
{
"bool": {
"must": [
{
"match": {
"text": {
"query": query_text,
}
}
}
],
"filter": filters
}
},
{
"knn": {
"vector_field": {
"vector": query_vector,
"k": k,
"filter": filters
}
}
},
]
}
},
"size": k,
}

return search_query
return payload


class OpenSearchVectorSearch(VectorStore):
Expand Down Expand Up @@ -1215,8 +1245,9 @@ def _raw_similarity_search_with_score_by_vector(

elif search_type == HYBRID_SEARCH:
search_pipeline = kwargs.get("search_pipeline")
post_filter = kwargs.get("post_filter", {})
filters = kwargs.get("filter", {})
query_text = kwargs.get("query_text")
index_name = kwargs.get("index_name", self.index_name)
path = f"/{index_name}/_search?search_pipeline={search_pipeline}"

if query_text is None:
Expand All @@ -1225,18 +1256,15 @@ def _raw_similarity_search_with_score_by_vector(
if search_pipeline is None:
raise ValueError("search_pipeline must be provided for hybrid search")

# embedding the query_text
embeded_query = self.embedding_function.embed_query(query_text)

# if a filter is provided
if post_filter != {}:
if filters != {}:
# hybrid search with the filter applied inside each sub-query
payload = _hybrid_search_query_with_post_filter(
query_text, embeded_query, k, post_filter
payload = _hybrid_search_query_with_filter(
query_text, embedding, k, filters
)
else:
# hybrid search without a filter
payload = _default_hybrid_search_query(query_text, embeded_query, k)
payload = _default_hybrid_search_query(query_text, embedding, k)

response = self.client.transport.perform_request(
method="GET", url=path, body=payload
Expand Down
Loading