diff --git a/libs/chatchat-server/chatchat/server/file_rag/retrievers/milvus_vectorstore.py b/libs/chatchat-server/chatchat/server/file_rag/retrievers/milvus_vectorstore.py index 639b739e5f..dfdedeb919 100644 --- a/libs/chatchat-server/chatchat/server/file_rag/retrievers/milvus_vectorstore.py +++ b/libs/chatchat-server/chatchat/server/file_rag/retrievers/milvus_vectorstore.py @@ -21,18 +21,9 @@ def _get_relevant_documents( if self.search_type == "similarity": docs = self.vectorstore.similarity_search(query, **self.search_kwargs) elif self.search_type == "similarity_score_threshold": - docs_and_similarities = self.vectorstore.similarity_search_with_score(query, **self.search_kwargs) + docs_and_similarities = self.vectorstore.similarity_search_with_relevance_scores(query, **self.search_kwargs) score_threshold = self.search_kwargs.get("score_threshold", None) - if any( - similarity < 0.0 or similarity > 1.0 - for _, similarity in docs_and_similarities - ): - warnings.warn( - "Relevance scores must be between" - f" 0 and 1, got {docs_and_similarities}" - ) - if score_threshold is not None: # can be 0, but not None docs_and_similarities = [ doc @@ -62,22 +53,13 @@ async def _aget_relevant_documents( ) elif self.search_type == "similarity_score_threshold": docs_and_similarities = ( - await self.vectorstore.asimilarity_search_with_score(query, **self.search_kwargs) + await self.vectorstore.asimilarity_search_with_relevance_scores(query, **self.search_kwargs) ) score_threshold = self.search_kwargs.get("score_threshold", None) - - if any( - similarity < 0.0 or similarity > 1.0 - for _, similarity in docs_and_similarities - ): - warnings.warn( - "Relevance scores must be between" - f" 0 and 1, got {docs_and_similarities}" - ) if score_threshold is not None: # can be 0, but not None docs_and_similarities = [ - (doc, similarity) + doc for doc, similarity in docs_and_similarities if similarity >= score_threshold ] diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py index ef06c242e7..9362484ea4 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/milvus_kb_service.py @@ -1,5 +1,5 @@ import os -from typing import Dict, List, Optional +from typing import Callable, Dict, List, Optional from langchain.schema import Document from langchain.vectorstores.milvus import Milvus @@ -79,7 +79,7 @@ def do_search(self, query: str, top_k: int, score_threshold: float): self._load_milvus() # embed_func = get_Embeddings(self.embed_model) # embeddings = embed_func.embed_query(query) - # docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k) + self.milvus._select_relevance_score_fn = self._select_relevance_score_fn retriever = get_Retriever("milvusvectorstore").from_vectorstore( self.milvus, top_k=top_k, @@ -121,6 +121,38 @@ def do_clear_vs(self): self.do_drop_kb() self.do_init() + def _select_relevance_score_fn(self) -> Callable[[float], float]: + def _map_l2_to_similarity(l2_distance: float) -> float: + """Return a similarity score on a scale [0, 1]. + It is recommended that the original vector is normalized, + Milvus only calculates the value before applying square root. + l2_distance range: (0 is most similar, 4 most dissimilar) + See + https://milvus.io/docs/metric.md?tab=floating#Euclidean-distance-L2 + """ + return 1 - l2_distance / 4.0 + + def _map_ip_to_similarity(ip_score: float) -> float: + """Return a similarity score on a scale [0, 1]. + It is recommended that the original vector is normalized, + ip_score range: (1 is most similar, -1 most dissimilar) + See + https://milvus.io/docs/metric.md?tab=floating#Inner-product-IP + https://milvus.io/docs/metric.md?tab=floating#Cosine-Similarity + """ + return (ip_score + 1) / 2.0 + + metric_type = self.milvus.search_params.get("metric_type") + if metric_type == "L2": + return _map_l2_to_similarity + elif metric_type in ["IP", "COSINE"]: + return _map_ip_to_similarity + else: + raise ValueError( + "No supported normalization function" + f" for metric type: {metric_type}." + ) + if __name__ == "__main__": # 测试建表使用