Skip to content

Commit e59b277

Browse files
authored
[breaking] Fix ColbertVectorStore from_texts (#612)
* [breaking] Fix ColbertVectorStore from_texts
* Wrap BaseEmbeddingModel in a LangChain Embeddings implementation class
* Changes following review
1 parent 3feb4fd commit e59b277

File tree

3 files changed

+99
-19
lines changed

3 files changed

+99
-19
lines changed

libs/langchain/ragstack_langchain/colbert/colbert_vector_store.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar
22

33
from langchain_core.documents import Document
4+
from langchain_core.embeddings import Embeddings
45
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
56
from ragstack_colbert import Chunk
67
from ragstack_colbert import ColbertVectorStore as RagstackColbertVectorStore
@@ -12,6 +13,8 @@
1213
from ragstack_colbert.base_vector_store import BaseVectorStore as ColbertBaseVectorStore
1314
from typing_extensions import override
1415

16+
from ragstack_langchain.colbert.embedding import TokensEmbeddings
17+
1518
CVS = TypeVar("CVS", bound="ColbertVectorStore")
1619

1720

@@ -208,8 +211,9 @@ async def asimilarity_search_with_score(
208211
def from_documents(
209212
cls,
210213
documents: List[Document],
211-
database: ColbertBaseDatabase,
212-
embedding_model: ColbertBaseEmbeddingModel,
214+
embedding: Embeddings,
215+
*,
216+
database: Optional[ColbertBaseDatabase] = None,
213217
**kwargs: Any,
214218
) -> CVS:
215219
"""Return VectorStore initialized from documents and embeddings."""
@@ -218,7 +222,7 @@ def from_documents(
218222
return cls.from_texts(
219223
texts=texts,
220224
database=database,
221-
embedding_model=embedding_model,
225+
embedding=embedding,
222226
metadatas=metadatas,
223227
**kwargs,
224228
)
@@ -228,8 +232,9 @@ def from_documents(
228232
async def afrom_documents(
229233
cls: Type[CVS],
230234
documents: List[Document],
231-
database: ColbertBaseDatabase,
232-
embedding_model: ColbertBaseEmbeddingModel,
235+
embedding: Embeddings,
236+
*,
237+
database: Optional[ColbertBaseDatabase] = None,
233238
concurrent_inserts: Optional[int] = 100,
234239
**kwargs: Any,
235240
) -> CVS:
@@ -239,7 +244,7 @@ async def afrom_documents(
239244
return await cls.afrom_texts(
240245
texts=texts,
241246
database=database,
242-
embedding_model=embedding_model,
247+
embedding=embedding,
243248
metadatas=metadatas,
244249
concurrent_inserts=concurrent_inserts,
245250
**kwargs,
@@ -250,13 +255,21 @@ async def afrom_documents(
250255
def from_texts(
251256
cls: Type[CVS],
252257
texts: List[str],
253-
database: ColbertBaseDatabase,
254-
embedding_model: ColbertBaseEmbeddingModel,
258+
embedding: Embeddings,
255259
metadatas: Optional[List[dict]] = None,
260+
*,
261+
database: Optional[ColbertBaseDatabase] = None,
256262
**kwargs: Any,
257263
) -> CVS:
258-
"""Return VectorStore initialized from texts and embeddings."""
259-
instance = cls(database=database, embedding_model=embedding_model, **kwargs)
264+
if not isinstance(embedding, TokensEmbeddings):
265+
raise TypeError("ColbertVectorStore requires a TokensEmbeddings embedding.")
266+
if database is None:
267+
raise ValueError(
268+
"ColbertVectorStore requires a ColbertBaseDatabase database."
269+
)
270+
instance = cls(
271+
database=database, embedding_model=embedding.get_embedding_model(), **kwargs
272+
)
260273
instance.add_texts(texts=texts, metadatas=metadatas)
261274
return instance
262275

@@ -265,14 +278,22 @@ def from_texts(
265278
async def afrom_texts(
266279
cls: Type[CVS],
267280
texts: List[str],
268-
database: ColbertBaseDatabase,
269-
embedding_model: ColbertBaseEmbeddingModel,
281+
embedding: Embeddings,
270282
metadatas: Optional[List[dict]] = None,
283+
*,
284+
database: Optional[ColbertBaseDatabase] = None,
271285
concurrent_inserts: Optional[int] = 100,
272286
**kwargs: Any,
273287
) -> CVS:
274-
"""Return VectorStore initialized from texts and embeddings."""
275-
instance = cls(database=database, embedding_model=embedding_model, **kwargs)
288+
if not isinstance(embedding, TokensEmbeddings):
289+
raise TypeError("ColbertVectorStore requires a TokensEmbeddings embedding.")
290+
if database is None:
291+
raise ValueError(
292+
"ColbertVectorStore requires a ColbertBaseDatabase database."
293+
)
294+
instance = cls(
295+
database=database, embedding_model=embedding.get_embedding_model(), **kwargs
296+
)
276297
await instance.aadd_texts(
277298
texts=texts, metadatas=metadatas, concurrent_inserts=concurrent_inserts
278299
)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from typing import List, Optional
2+
3+
from langchain_core.embeddings import Embeddings
4+
from ragstack_colbert import DEFAULT_COLBERT_MODEL, ColbertEmbeddingModel
5+
from ragstack_colbert.base_embedding_model import BaseEmbeddingModel
6+
from typing_extensions import override
7+
8+
9+
class TokensEmbeddings(Embeddings):
10+
"""Adapter for token-based embedding models and the LangChain Embeddings."""
11+
12+
def __init__(self, embedding: BaseEmbeddingModel = None):
13+
self.embedding = embedding or ColbertEmbeddingModel()
14+
15+
@override
16+
def embed_documents(self, texts: List[str]) -> List[List[float]]:
17+
raise NotImplementedError
18+
19+
@override
20+
def embed_query(self, text: str) -> List[float]:
21+
raise NotImplementedError
22+
23+
@override
24+
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
25+
raise NotImplementedError
26+
27+
@override
28+
async def aembed_query(self, text: str) -> List[float]:
29+
raise NotImplementedError
30+
31+
def get_embedding_model(self) -> BaseEmbeddingModel:
32+
"""Get the embedding model."""
33+
return self.embedding
34+
35+
@staticmethod
36+
def colbert(
37+
checkpoint: str = DEFAULT_COLBERT_MODEL,
38+
doc_maxlen: int = 256,
39+
nbits: int = 2,
40+
kmeans_niters: int = 4,
41+
nranks: int = -1,
42+
query_maxlen: Optional[int] = None,
43+
verbose: int = 3,
44+
chunk_batch_size: int = 640,
45+
):
46+
"""Create a new ColBERT embedding model."""
47+
return TokensEmbeddings(
48+
ColbertEmbeddingModel(
49+
checkpoint,
50+
doc_maxlen,
51+
nbits,
52+
kmeans_niters,
53+
nranks,
54+
query_maxlen,
55+
verbose,
56+
chunk_batch_size,
57+
)
58+
)

libs/langchain/tests/integration_tests/test_colbert.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from cassandra.cluster import Session
66
from langchain.text_splitter import RecursiveCharacterTextSplitter
77
from langchain_core.documents import Document
8-
from ragstack_colbert import CassandraDatabase, ColbertEmbeddingModel
8+
from ragstack_colbert import CassandraDatabase
99
from ragstack_langchain.colbert import ColbertVectorStore
10+
from ragstack_langchain.colbert.embedding import TokensEmbeddings
1011
from ragstack_tests_utils import TestData
1112
from transformers import BertTokenizer
1213

@@ -72,7 +73,7 @@ def test_sync_from_docs(session: Session) -> None:
7273
batch_size = 5 # 640 recommended for production use
7374
chunk_size = 250
7475

75-
embedding_model = ColbertEmbeddingModel(
76+
embedding = TokensEmbeddings.colbert(
7677
doc_maxlen=chunk_size,
7778
chunk_batch_size=batch_size,
7879
)
@@ -81,7 +82,7 @@ def test_sync_from_docs(session: Session) -> None:
8182

8283
doc_chunks: List[Document] = get_test_chunks()
8384
vector_store: ColbertVectorStore = ColbertVectorStore.from_documents(
84-
documents=doc_chunks, database=database, embedding_model=embedding_model
85+
documents=doc_chunks, database=database, embedding=embedding
8586
)
8687

8788
results: List[Document] = vector_store.similarity_search(
@@ -124,7 +125,7 @@ async def test_async_from_docs(session: Session) -> None:
124125
batch_size = 5 # 640 recommended for production use
125126
chunk_size = 250
126127

127-
embedding_model = ColbertEmbeddingModel(
128+
embedding = TokensEmbeddings.colbert(
128129
doc_maxlen=chunk_size,
129130
chunk_batch_size=batch_size,
130131
)
@@ -133,7 +134,7 @@ async def test_async_from_docs(session: Session) -> None:
133134

134135
doc_chunks: List[Document] = get_test_chunks()
135136
vector_store: ColbertVectorStore = await ColbertVectorStore.afrom_documents(
136-
documents=doc_chunks, database=database, embedding_model=embedding_model
137+
documents=doc_chunks, database=database, embedding=embedding
137138
)
138139

139140
results: List[Document] = await vector_store.asimilarity_search(

0 commit comments

Comments
 (0)