Skip to content

Commit 351f17a

Browse files
committed
Add min_score_confidence support for the Bedrock KB retriever
- add min_score_confidence field to AmazonKnowledgeBasesRetriever class - add unit tests and integration tests to test score filtering - update Kendra retriever _filter_by_score_confidence docstring
1 parent 3dfbad4 commit 351f17a

File tree

4 files changed

+92
-3
lines changed

4 files changed

+92
-3
lines changed

libs/aws/langchain_aws/retrievers/bedrock.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from botocore.exceptions import UnknownServiceError
66
from langchain_core.callbacks import CallbackManagerForRetrieverRun
77
from langchain_core.documents import Document
8-
from langchain_core.pydantic_v1 import BaseModel, root_validator
8+
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
99
from langchain_core.retrievers import BaseRetriever
10+
from typing_extensions import Annotated
1011

1112

1213
class VectorSearchConfig(BaseModel, extra="allow"): # type: ignore[call-arg]
@@ -59,6 +60,7 @@ class AmazonKnowledgeBasesRetriever(BaseRetriever):
5960
endpoint_url: Optional[str] = None
6061
client: Any
6162
retrieval_config: RetrievalConfig
63+
min_score_confidence: Annotated[Optional[float], Field(ge=0.0, le=1.0)]
6264

6365
@root_validator(pre=True)
6466
def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@@ -103,6 +105,23 @@ def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
103105
"profile name are valid."
104106
) from e
105107

108+
def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]:
    """Drop documents whose score is below ``min_score_confidence``.

    Args:
        docs: Candidate documents; each one's score is read from
            ``metadata["score"]``.

    Returns:
        ``docs`` itself, unchanged, when ``min_score_confidence`` is falsy
        (``None`` or ``0.0``). Otherwise a new list containing only the
        documents whose score is present and greater than or equal to
        ``min_score_confidence``.
    """
    threshold = self.min_score_confidence
    # A falsy threshold (None or 0.0) disables filtering entirely, so
    # documents that carry no score at all are still returned.
    if not threshold:
        return docs

    kept: List[Document] = []
    for doc in docs:
        # Look the score up once. A missing or None score can never satisfy
        # the threshold, so such documents are dropped.
        score = doc.metadata.get("score")
        if score is not None and score >= threshold:
            kept.append(doc)
    return kept
124+
106125
def _get_relevant_documents(
107126
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
108127
) -> List[Document]:
@@ -127,4 +146,4 @@ def _get_relevant_documents(
127146
)
128147
)
129148

130-
return documents
149+
return self._filter_by_score_confidence(docs=documents)

libs/aws/langchain_aws/retrievers/kendra.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def _get_top_k_docs(self, result_items: Sequence[ResultItem]) -> List[Document]:
444444
def _filter_by_score_confidence(self, docs: List[Document]) -> List[Document]:
445445
"""
446446
Filter out the records that have a score confidence
447-
greater than the required threshold.
447+
less than the required threshold.
448448
"""
449449
if not self.min_score_confidence:
450450
return docs

libs/aws/tests/integration_tests/retrievers/test_amazon_knowledgebases_retriever.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def retriever(mock_client: Mock) -> AmazonKnowledgeBasesRetriever:
1717
knowledge_base_id="test-knowledge-base",
1818
client=mock_client,
1919
retrieval_config={"vectorSearchConfiguration": {"numberOfResults": 4}}, # type: ignore[arg-type]
20+
min_score_confidence=0.0,
2021
)
2122

2223

@@ -78,3 +79,44 @@ def test_get_relevant_documents(retriever, mock_client) -> None: # type: ignore
7879
knowledgeBaseId="test-knowledge-base",
7980
retrievalConfiguration={"vectorSearchConfiguration": {"numberOfResults": 4}},
8081
)
82+
83+
84+
def test_get_relevant_documents_with_score(retriever, mock_client) -> None:  # type: ignore[no-untyped-def]
    """Only results scoring at least ``min_score_confidence`` are returned."""
    # Two scored results, followed by two results that carry no score.
    scored_results = [
        {
            "content": {"text": "This is the first result."},
            "location": "location1",
            "score": 0.9,
        },
        {
            "content": {"text": "This is the second result."},
            "location": "location2",
            "score": 0.8,
        },
    ]
    unscored_results = [
        {"content": {"text": "This is the third result."}, "location": "location3"},
        {
            "content": {"text": "This is the fourth result."},
            "metadata": {"key1": "value1", "key2": "value2"},
        },
    ]
    mock_client.retrieve.return_value = {
        "retrievalResults": scored_results + unscored_results
    }

    # The second result's score equals the threshold and is expected to be
    # kept; the unscored results are expected to be dropped.
    retriever.min_score_confidence = 0.80
    documents = retriever.invoke("test query")

    assert documents == [
        Document(
            page_content="This is the first result.",
            metadata={"location": "location1", "score": 0.9},
        ),
        Document(
            page_content="This is the second result.",
            metadata={"location": "location2", "score": 0.8},
        ),
    ]

libs/aws/tests/unit_tests/retrievers/test_bedrock.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,31 @@ def test_retriever_invoke(amazon_retriever, mock_client):
5555
}
5656
assert documents[2].page_content == "result3"
5757
assert documents[2].metadata == {"score": 0}
58+
59+
60+
def test_retriever_invoke_with_score(amazon_retriever, mock_client):
    """Only the scored result survives a ``min_score_confidence`` of 0.6."""
    # result1 and result3 carry no score; result2 scores 1.
    unscored_with_metadata = {
        "content": {"text": "result1"},
        "metadata": {"key": "value1"},
    }
    scored = {
        "content": {"text": "result2"},
        "metadata": {"key": "value2"},
        "score": 1,
        "location": "testLocation",
    }
    unscored_bare = {"content": {"text": "result3"}}
    mock_client.retrieve.return_value = {
        "retrievalResults": [unscored_with_metadata, scored, unscored_bare]
    }

    amazon_retriever.min_score_confidence = 0.6
    documents = amazon_retriever.invoke("test query", run_manager=None)

    assert len(documents) == 1
    only_doc = documents[0]
    assert isinstance(only_doc, Document)
    assert only_doc.page_content == "result2"
    assert only_doc.metadata == {
        "score": 1,
        "source_metadata": {"key": "value2"},
        "location": "testLocation",
    }

0 commit comments

Comments
 (0)