Skip to content

Commit e78d791

Browse files
authored
Fix ColbertVectorStore as_retriever() (#611)
1 parent a88325d commit e78d791

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

libs/langchain/ragstack_langchain/colbert/colbert_vector_store.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar
22

33
from langchain_core.documents import Document
4-
from langchain_core.retrievers import BaseRetriever
5-
from langchain_core.vectorstores import VectorStore
4+
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
65
from ragstack_colbert import Chunk
76
from ragstack_colbert import ColbertVectorStore as RagstackColbertVectorStore
87
from ragstack_colbert.base_database import BaseDatabase as ColbertBaseDatabase
@@ -13,8 +12,6 @@
1312
from ragstack_colbert.base_vector_store import BaseVectorStore as ColbertBaseVectorStore
1413
from typing_extensions import override
1514

16-
from .colbert_retriever import ColbertRetriever
17-
1815
CVS = TypeVar("CVS", bound="ColbertVectorStore")
1916

2017

@@ -282,8 +279,11 @@ async def afrom_texts(
282279
return instance
283280

284281
@override
285-
def as_retriever(self, k: Optional[int] = 5, **kwargs: Any) -> BaseRetriever:
282+
def as_retriever(self, k: Optional[int] = 5, **kwargs: Any) -> VectorStoreRetriever:
286283
"""Return a VectorStoreRetriever initialized from this VectorStore."""
287-
return ColbertRetriever(
288-
retriever=self._vector_store.as_retriever(), k=k, **kwargs
289-
)
284+
search_kwargs = kwargs.pop("search_kwargs", {})
285+
search_kwargs["k"] = k
286+
search_type = kwargs.get("search_type", "similarity")
287+
if search_type != "similarity":
288+
raise ValueError(f"Unsupported search type: {search_type}")
289+
return super().as_retriever(search_kwargs=search_kwargs, **kwargs)

0 commit comments

Comments
 (0)