1
1
from typing import Any , Iterable , List , Optional , Tuple , Type , TypeVar
2
2
3
3
from langchain_core .documents import Document
4
+ from langchain_core .embeddings import Embeddings
4
5
from langchain_core .vectorstores import VectorStore , VectorStoreRetriever
5
6
from ragstack_colbert import Chunk
6
7
from ragstack_colbert import ColbertVectorStore as RagstackColbertVectorStore
12
13
from ragstack_colbert .base_vector_store import BaseVectorStore as ColbertBaseVectorStore
13
14
from typing_extensions import override
14
15
16
+ from ragstack_langchain .colbert .embedding import TokensEmbeddings
17
+
15
18
CVS = TypeVar ("CVS" , bound = "ColbertVectorStore" )
16
19
17
20
@@ -208,8 +211,9 @@ async def asimilarity_search_with_score(
208
211
def from_documents (
209
212
cls ,
210
213
documents : List [Document ],
211
- database : ColbertBaseDatabase ,
212
- embedding_model : ColbertBaseEmbeddingModel ,
214
+ embedding : Embeddings ,
215
+ * ,
216
+ database : Optional [ColbertBaseDatabase ] = None ,
213
217
** kwargs : Any ,
214
218
) -> CVS :
215
219
"""Return VectorStore initialized from documents and embeddings."""
@@ -218,7 +222,7 @@ def from_documents(
218
222
return cls .from_texts (
219
223
texts = texts ,
220
224
database = database ,
221
- embedding_model = embedding_model ,
225
+ embedding = embedding ,
222
226
metadatas = metadatas ,
223
227
** kwargs ,
224
228
)
@@ -228,8 +232,9 @@ def from_documents(
228
232
async def afrom_documents (
229
233
cls : Type [CVS ],
230
234
documents : List [Document ],
231
- database : ColbertBaseDatabase ,
232
- embedding_model : ColbertBaseEmbeddingModel ,
235
+ embedding : Embeddings ,
236
+ * ,
237
+ database : Optional [ColbertBaseDatabase ] = None ,
233
238
concurrent_inserts : Optional [int ] = 100 ,
234
239
** kwargs : Any ,
235
240
) -> CVS :
@@ -239,7 +244,7 @@ async def afrom_documents(
239
244
return await cls .afrom_texts (
240
245
texts = texts ,
241
246
database = database ,
242
- embedding_model = embedding_model ,
247
+ embedding = embedding ,
243
248
metadatas = metadatas ,
244
249
concurrent_inserts = concurrent_inserts ,
245
250
** kwargs ,
@@ -250,13 +255,21 @@ async def afrom_documents(
250
255
def from_texts (
251
256
cls : Type [CVS ],
252
257
texts : List [str ],
253
- database : ColbertBaseDatabase ,
254
- embedding_model : ColbertBaseEmbeddingModel ,
258
+ embedding : Embeddings ,
255
259
metadatas : Optional [List [dict ]] = None ,
260
+ * ,
261
+ database : Optional [ColbertBaseDatabase ] = None ,
256
262
** kwargs : Any ,
257
263
) -> CVS :
258
- """Return VectorStore initialized from texts and embeddings."""
259
- instance = cls (database = database , embedding_model = embedding_model , ** kwargs )
264
+ if not isinstance (embedding , TokensEmbeddings ):
265
+ raise TypeError ("ColbertVectorStore requires a TokensEmbeddings embedding." )
266
+ if database is None :
267
+ raise ValueError (
268
+ "ColbertVectorStore requires a ColbertBaseDatabase database."
269
+ )
270
+ instance = cls (
271
+ database = database , embedding_model = embedding .get_embedding_model (), ** kwargs
272
+ )
260
273
instance .add_texts (texts = texts , metadatas = metadatas )
261
274
return instance
262
275
@@ -265,14 +278,22 @@ def from_texts(
265
278
async def afrom_texts (
266
279
cls : Type [CVS ],
267
280
texts : List [str ],
268
- database : ColbertBaseDatabase ,
269
- embedding_model : ColbertBaseEmbeddingModel ,
281
+ embedding : Embeddings ,
270
282
metadatas : Optional [List [dict ]] = None ,
283
+ * ,
284
+ database : Optional [ColbertBaseDatabase ] = None ,
271
285
concurrent_inserts : Optional [int ] = 100 ,
272
286
** kwargs : Any ,
273
287
) -> CVS :
274
- """Return VectorStore initialized from texts and embeddings."""
275
- instance = cls (database = database , embedding_model = embedding_model , ** kwargs )
288
+ if not isinstance (embedding , TokensEmbeddings ):
289
+ raise TypeError ("ColbertVectorStore requires a TokensEmbeddings embedding." )
290
+ if database is None :
291
+ raise ValueError (
292
+ "ColbertVectorStore requires a ColbertBaseDatabase database."
293
+ )
294
+ instance = cls (
295
+ database = database , embedding_model = embedding .get_embedding_model (), ** kwargs
296
+ )
276
297
await instance .aadd_texts (
277
298
texts = texts , metadatas = metadatas , concurrent_inserts = concurrent_inserts
278
299
)
0 commit comments