Skip to content

Commit 983b73f

Browse files
authored
add search kwargs (#664)
1 parent 65f3a34 commit 983b73f

File tree

7 files changed

+22
-9
lines changed

7 files changed

+22
-9
lines changed

langchain/chains/qa_with_sources/vector_db.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Question-answering with sources over a vector database."""
22
from typing import Any, Dict, List
33

4-
from pydantic import BaseModel
4+
from pydantic import BaseModel, Field
55

66
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
77
from langchain.docstore.document import Document
@@ -15,8 +15,8 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
1515
"""Vector Database to connect to."""
1616
k: int = 4
1717
"""Number of results to return from store"""
18-
search_kwargs: Dict[str, Any] = {}
19-
"""Extra search args"""
18+
search_kwargs: Dict[str, Any] = Field(default_factory=dict)
19+
"""Extra search args."""
2020

2121
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
2222
question = inputs[self.question_key]

langchain/chains/vector_db_qa/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from typing import Any, Dict, List
55

6-
from pydantic import BaseModel, Extra, root_validator
6+
from pydantic import BaseModel, Extra, Field, root_validator
77

88
from langchain.chains.base import Chain
99
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@@ -39,6 +39,8 @@ class VectorDBQA(Chain, BaseModel):
3939
output_key: str = "result" #: :meta private:
4040
return_source_documents: bool = False
4141
"""Return the source documents."""
42+
search_kwargs: Dict[str, Any] = Field(default_factory=dict)
43+
"""Extra search args."""
4244

4345
class Config:
4446
"""Configuration for this pydantic object."""
@@ -127,7 +129,9 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
127129
"""
128130
question = inputs[self.input_key]
129131

130-
docs = self.vectorstore.similarity_search(question, k=self.k)
132+
docs = self.vectorstore.similarity_search(
133+
question, k=self.k, **self.search_kwargs
134+
)
131135
answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)
132136

133137
if self.return_source_documents:

langchain/vectorstores/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ def add_texts(
2626
"""
2727

2828
@abstractmethod
29-
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
29+
def similarity_search(
30+
self, query: str, k: int = 4, **kwargs: Any
31+
) -> List[Document]:
3032
"""Return docs most similar to query."""
3133

3234
def max_marginal_relevance_search(

langchain/vectorstores/elastic_vector_search.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ def add_texts(
106106
self.client.indices.refresh(index=self.index_name)
107107
return ids
108108

109-
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
109+
def similarity_search(
110+
self, query: str, k: int = 4, **kwargs: Any
111+
) -> List[Document]:
110112
"""Return docs most similar to query.
111113
112114
Args:

langchain/vectorstores/faiss.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ def similarity_search_with_score(
103103
docs.append((doc, scores[0][j]))
104104
return docs
105105

106-
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
106+
def similarity_search(
107+
self, query: str, k: int = 4, **kwargs: Any
108+
) -> List[Document]:
107109
"""Return docs most similar to query.
108110
109111
Args:

langchain/vectorstores/pinecone.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def similarity_search(
120120
k: int = 5,
121121
filter: Optional[dict] = None,
122122
namespace: Optional[str] = None,
123+
**kwargs: Any,
123124
) -> List[Document]:
124125
"""Return pinecone documents most similar to query.
125126

langchain/vectorstores/weaviate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ def add_texts(
7171
ids.append(_id)
7272
return ids
7373

74-
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
74+
def similarity_search(
75+
self, query: str, k: int = 4, **kwargs: Any
76+
) -> List[Document]:
7577
"""Look up similar documents in weaviate."""
7678
content = {"concepts": [query]}
7779
query_obj = self._client.query.get(self._index_name, self._query_attrs)

0 commit comments

Comments
 (0)