Skip to content

Commit 3f02c1a

Browse files
fix: Added more tolerant validation for embeddings type (#85)
* fix: Added more tolerant validation for embeddings type * poetry update
1 parent 953e021 commit 3f02c1a

File tree

3 files changed

+154
-135
lines changed

3 files changed

+154
-135
lines changed

libs/langchain-db2/langchain_db2/db2vs.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
from langchain_core.embeddings import Embeddings
3636
from langchain_core.vectorstores import VectorStore
3737

38+
from langchain_db2.utils import EmbeddingsSchema
39+
3840
logger = logging.getLogger(__name__)
3941
log_level = os.getenv("LOG_LEVEL", "ERROR").upper()
4042
logging.basicConfig(
@@ -221,7 +223,7 @@ def __init__(
221223
self.client = client
222224
try:
223225
"""Initialize with necessary components."""
224-
if not isinstance(embedding_function, Embeddings):
226+
if not isinstance(embedding_function, EmbeddingsSchema):
225227
logger.warning(
226228
"`embedding_function` is expected to be an Embeddings "
227229
"object, support for passing in a function will soon "
@@ -263,7 +265,7 @@ def embeddings(self) -> Optional[Embeddings]:
263265
"""
264266
return (
265267
self.embedding_function
266-
if isinstance(self.embedding_function, Embeddings)
268+
if isinstance(self.embedding_function, EmbeddingsSchema)
267269
else None
268270
)
269271

@@ -277,7 +279,7 @@ def get_embedding_dimension(self) -> int:
277279
return len(embedded_document[0])
278280

279281
def _embed_documents(self, texts: List[str]) -> List[List[float]]:
280-
if isinstance(self.embedding_function, Embeddings):
282+
if isinstance(self.embedding_function, EmbeddingsSchema):
281283
return self.embedding_function.embed_documents(texts)
282284
elif callable(self.embedding_function):
283285
return [self.embedding_function(text) for text in texts]
@@ -287,7 +289,7 @@ def _embed_documents(self, texts: List[str]) -> List[List[float]]:
287289
)
288290

289291
def _embed_query(self, text: str) -> List[float]:
290-
if isinstance(self.embedding_function, Embeddings):
292+
if isinstance(self.embedding_function, EmbeddingsSchema):
291293
return self.embedding_function.embed_query(text)
292294
else:
293295
return self.embedding_function(text)
@@ -407,7 +409,7 @@ def similarity_search(
407409
Return:
408410
List[Document]: documents most similar to a query
409411
"""
410-
if isinstance(self.embedding_function, Embeddings):
412+
if isinstance(self.embedding_function, EmbeddingsSchema):
411413
embedding = self.embedding_function.embed_query(query)
412414
documents = self.similarity_search_by_vector(
413415
embedding=embedding, k=k, filter=filter, **kwargs
@@ -434,7 +436,7 @@ def similarity_search_with_score(
434436
**kwargs: Any,
435437
) -> List[Tuple[Document, float]]:
436438
"""Return docs most similar to query."""
437-
if isinstance(self.embedding_function, Embeddings):
439+
if isinstance(self.embedding_function, EmbeddingsSchema):
438440
embedding = self.embedding_function.embed_query(query)
439441
docs_and_scores = self.similarity_search_by_vector_with_relevance_scores(
440442
embedding=embedding, k=k, filter=filter, **kwargs
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from typing import Protocol, runtime_checkable
2+
3+
4+
@runtime_checkable
class EmbeddingsSchema(Protocol):
    """Structural (duck-typed) description of an embeddings provider.

    Any object that exposes both ``embed_documents`` and ``embed_query``
    satisfies this protocol; it does not need to inherit from a particular
    base class.

    Because the protocol is ``@runtime_checkable`` it can be used with
    ``isinstance``. Note that a runtime protocol check only verifies that
    the members exist on the object -- it does NOT validate their
    signatures or return types.
    """

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed several texts, returning a list of float vectors."""
        ...

    def embed_query(self, text: str) -> list[float]:
        """Embed a single text, returning one float vector."""
        ...

0 commit comments

Comments
 (0)