@@ -41,6 +41,8 @@ class VectorDBQA(Chain, BaseModel):
41
41
"""Return the source documents."""
42
42
search_kwargs : Dict [str , Any ] = Field (default_factory = dict )
43
43
"""Extra search args."""
44
+ search_type : str = "similarity"
45
+ """Search type to use over vectorstore. `similarity` or `mmr`."""
44
46
45
47
class Config :
46
48
"""Configuration for this pydantic object."""
@@ -90,6 +92,15 @@ def load_combine_documents_chain(cls, values: Dict) -> Dict:
90
92
values ["combine_documents_chain" ] = combine_documents_chain
91
93
return values
92
94
95
+ @root_validator ()
96
+ def validate_search_type (cls , values : Dict ) -> Dict :
97
+ """Validate search type."""
98
+ if "search_type" in values :
99
+ search_type = values ["search_type" ]
100
+ if search_type not in ("similarity" , "mmr" ):
101
+ raise ValueError (f"search_type of { search_type } not allowed." )
102
+ return values
103
+
93
104
@classmethod
94
105
def from_llm (
95
106
cls , llm : BaseLLM , prompt : PromptTemplate = PROMPT , ** kwargs : Any
@@ -129,9 +140,16 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
129
140
"""
130
141
question = inputs [self .input_key ]
131
142
132
- docs = self .vectorstore .similarity_search (
133
- question , k = self .k , ** self .search_kwargs
134
- )
143
+ if self .search_type == "similarity" :
144
+ docs = self .vectorstore .similarity_search (
145
+ question , k = self .k , ** self .search_kwargs
146
+ )
147
+ elif self .search_type == "mmr" :
148
+ docs = self .vectorstore .max_marginal_relevance_search (
149
+ question , k = self .k , ** self .search_kwargs
150
+ )
151
+ else :
152
+ raise ValueError (f"search_type of { self .search_type } not allowed." )
135
153
answer , _ = self .combine_documents_chain .combine_docs (docs , question = question )
136
154
137
155
if self .return_source_documents :
0 commit comments