
Commit 28efbb0

Add params to dynamically reduce K below the token limit (#739)
Referring to #687, I implemented functionality to reduce K if the retrieved documents exceed the token limit. Edit: I should have run `make lint` locally. Also, this only applies to `StuffDocumentsChain`.
1 parent d2f8821 commit 28efbb0

File tree

1 file changed (+28, -1)

langchain/chains/qa_with_sources/vector_db.py

Lines changed: 28 additions & 1 deletion
@@ -1,8 +1,10 @@
 """Question-answering with sources over a vector database."""
+
 from typing import Any, Dict, List

 from pydantic import BaseModel, Field

+from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
 from langchain.docstore.document import Document
 from langchain.vectorstores.base import VectorStore
@@ -15,11 +17,36 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
     """Vector Database to connect to."""
     k: int = 4
     """Number of results to return from store"""
+    reduce_k_below_max_tokens: bool = False
+    """Reduce the number of results to return from store based on tokens limit"""
+    max_tokens_limit: int = 3375
+    """Restrict the docs to return from store based on tokens,
+    enforced only for StuffDocumentsChain and if reduce_k_below_max_tokens is set to true"""
     search_kwargs: Dict[str, Any] = Field(default_factory=dict)
     """Extra search args."""

+    def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
+        num_docs = len(docs)
+
+        if self.reduce_k_below_max_tokens and isinstance(
+            self.combine_documents_chain, StuffDocumentsChain
+        ):
+            tokens = [
+                self.combine_documents_chain.llm_chain.llm.get_num_tokens(
+                    doc.page_content
+                )
+                for doc in docs
+            ]
+            token_count = sum(tokens[:num_docs])
+            while token_count > self.max_tokens_limit:
+                num_docs -= 1
+                token_count -= tokens[num_docs]
+
+        return docs[:num_docs]
+
     def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
         question = inputs[self.question_key]
-        return self.vectorstore.similarity_search(
+        docs = self.vectorstore.similarity_search(
             question, k=self.k, **self.search_kwargs
         )
+        return self._reduce_tokens_below_limit(docs)
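To make the trimming concrete, here is a small standalone walk-through of the greedy loop in _reduce_tokens_below_limit; the per-document token counts are made up for illustration:

# Hypothetical token counts for k=4 retrieved documents, best match first.
token_counts = [1200, 1100, 900, 800]
max_tokens_limit = 3375

num_docs = len(token_counts)
token_count = sum(token_counts)            # 4000 > 3375, so trimming kicks in
while token_count > max_tokens_limit:
    num_docs -= 1                          # drop the lowest-ranked document
    token_count -= token_counts[num_docs]  # 4000 - 800 = 3200 <= 3375
# num_docs == 3: only the first three documents are stuffed into the prompt

Because documents come back from similarity_search in decreasing order of relevance, trimming from the end discards the least relevant results first.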

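A minimal usage sketch of the new flags. This is hypothetical, not part of the commit: it assumes the OpenAI and FAISS integrations and the load_qa_with_sources_chain loader available in langchain at the time, and the corpus and question are placeholders. chain_type="stuff" matters because the trimming is only enforced for StuffDocumentsChain.

from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms.openai import OpenAI
from langchain.vectorstores.faiss import FAISS

llm = OpenAI(temperature=0)

# Build a small vector store over a placeholder corpus (assumed texts).
store = FAISS.from_texts(
    ["LangChain supports question answering with cited sources."],
    OpenAIEmbeddings(),
)

# chain_type="stuff" yields a StuffDocumentsChain, the only combine chain
# for which the new token-based trimming is enforced.
chain = VectorDBQAWithSourcesChain(
    combine_documents_chain=load_qa_with_sources_chain(llm, chain_type="stuff"),
    vectorstore=store,
    reduce_k_below_max_tokens=True,  # opt in to the new behavior
    max_tokens_limit=3375,           # default introduced by this commit
)
result = chain({"question": "What does LangChain support?"})

With reduce_k_below_max_tokens left at its default of False, retrieval behavior is unchanged from before this commit.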
0 commit comments