35
35
from langchain_core .embeddings import Embeddings
36
36
from langchain_core .vectorstores import VectorStore
37
37
38
+ from langchain_db2 .utils import EmbeddingsSchema
39
+
38
40
logger = logging .getLogger (__name__ )
39
41
log_level = os .getenv ("LOG_LEVEL" , "ERROR" ).upper ()
40
42
logging .basicConfig (
@@ -221,7 +223,7 @@ def __init__(
221
223
self .client = client
222
224
try :
223
225
"""Initialize with necessary components."""
224
- if not isinstance (embedding_function , Embeddings ):
226
+ if not isinstance (embedding_function , EmbeddingsSchema ):
225
227
logger .warning (
226
228
"`embedding_function` is expected to be an Embeddings "
227
229
"object, support for passing in a function will soon "
@@ -263,7 +265,7 @@ def embeddings(self) -> Optional[Embeddings]:
263
265
"""
264
266
return (
265
267
self .embedding_function
266
- if isinstance (self .embedding_function , Embeddings )
268
+ if isinstance (self .embedding_function , EmbeddingsSchema )
267
269
else None
268
270
)
269
271
@@ -277,7 +279,7 @@ def get_embedding_dimension(self) -> int:
277
279
return len (embedded_document [0 ])
278
280
279
281
def _embed_documents (self , texts : List [str ]) -> List [List [float ]]:
280
- if isinstance (self .embedding_function , Embeddings ):
282
+ if isinstance (self .embedding_function , EmbeddingsSchema ):
281
283
return self .embedding_function .embed_documents (texts )
282
284
elif callable (self .embedding_function ):
283
285
return [self .embedding_function (text ) for text in texts ]
@@ -287,7 +289,7 @@ def _embed_documents(self, texts: List[str]) -> List[List[float]]:
287
289
)
288
290
289
291
def _embed_query (self , text : str ) -> List [float ]:
290
- if isinstance (self .embedding_function , Embeddings ):
292
+ if isinstance (self .embedding_function , EmbeddingsSchema ):
291
293
return self .embedding_function .embed_query (text )
292
294
else :
293
295
return self .embedding_function (text )
@@ -407,7 +409,7 @@ def similarity_search(
407
409
Return:
408
410
List[Document]: documents most similar to a query
409
411
"""
410
- if isinstance (self .embedding_function , Embeddings ):
412
+ if isinstance (self .embedding_function , EmbeddingsSchema ):
411
413
embedding = self .embedding_function .embed_query (query )
412
414
documents = self .similarity_search_by_vector (
413
415
embedding = embedding , k = k , filter = filter , ** kwargs
@@ -434,7 +436,7 @@ def similarity_search_with_score(
434
436
** kwargs : Any ,
435
437
) -> List [Tuple [Document , float ]]:
436
438
"""Return docs most similar to query."""
437
- if isinstance (self .embedding_function , Embeddings ):
439
+ if isinstance (self .embedding_function , EmbeddingsSchema ):
438
440
embedding = self .embedding_function .embed_query (query )
439
441
docs_and_scores = self .similarity_search_by_vector_with_relevance_scores (
440
442
embedding = embedding , k = k , filter = filter , ** kwargs
0 commit comments