Skip to content

Commit b3916f7

Browse files
authored
enable mmr search (#807)
1 parent f46f1d2 commit b3916f7

File tree

1 file changed

+21
-3
lines changed
  • langchain/chains/vector_db_qa

1 file changed

+21
-3
lines changed

langchain/chains/vector_db_qa/base.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ class VectorDBQA(Chain, BaseModel):
4141
"""Return the source documents."""
4242
search_kwargs: Dict[str, Any] = Field(default_factory=dict)
4343
"""Extra search args."""
44+
search_type: str = "similarity"
45+
"""Search type to use over vectorstore. `similarity` or `mmr`."""
4446

4547
class Config:
4648
"""Configuration for this pydantic object."""
@@ -90,6 +92,15 @@ def load_combine_documents_chain(cls, values: Dict) -> Dict:
9092
values["combine_documents_chain"] = combine_documents_chain
9193
return values
9294

95+
@root_validator()
96+
def validate_search_type(cls, values: Dict) -> Dict:
97+
"""Validate search type."""
98+
if "search_type" in values:
99+
search_type = values["search_type"]
100+
if search_type not in ("similarity", "mmr"):
101+
raise ValueError(f"search_type of {search_type} not allowed.")
102+
return values
103+
93104
@classmethod
94105
def from_llm(
95106
cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any
@@ -129,9 +140,16 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
129140
"""
130141
question = inputs[self.input_key]
131142

132-
docs = self.vectorstore.similarity_search(
133-
question, k=self.k, **self.search_kwargs
134-
)
143+
if self.search_type == "similarity":
144+
docs = self.vectorstore.similarity_search(
145+
question, k=self.k, **self.search_kwargs
146+
)
147+
elif self.search_type == "mmr":
148+
docs = self.vectorstore.max_marginal_relevance_search(
149+
question, k=self.k, **self.search_kwargs
150+
)
151+
else:
152+
raise ValueError(f"search_type of {self.search_type} not allowed.")
135153
answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)
136154

137155
if self.return_source_documents:

0 commit comments

Comments
 (0)