Skip to content

Commit e4d600a

Browse files
committed
Native IRIS Vector datatype
1 parent bf187ff commit e4d600a

File tree

4 files changed

+65
-14
lines changed

4 files changed

+65
-14
lines changed

docker-compose.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
version: "3"
22
services:
33
iris:
4-
image: intersystemsdc/iris-community:2023.3-zpm
4+
# image: intersystemsdc/iris-community:2023.3-zpm
5+
image: caretdev/iris-community:2024.1-vecdb
56
ports:
67
- 6172:1972
78
- 6173:52773

langchain_iris/vectorstores.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
func,
4141
)
4242
from sqlalchemy_iris import IRISListBuild
43+
from sqlalchemy_iris import IRISVector as IRISVectorType
4344

4445
from sqlalchemy.orm import Session
4546

@@ -75,10 +76,13 @@ class DistanceStrategy(str, enum.Enum):
7576

7677
class IRISVector(VectorStore):
7778
_conn = None
79+
native_vector = False
80+
native_vector_cosine_similarity = False
7881

7982
def __init__(
8083
self,
8184
embedding_function: Embeddings,
85+
dimension: int,
8286
connection_string: Optional[str] = None,
8387
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
8488
pre_delete_collection: bool = False,
@@ -91,6 +95,7 @@ def __init__(
9195
) -> None:
9296
self.connection_string = connection_string or "iris+emb:///"
9397
self.embedding_function = embedding_function
98+
self.dimension = dimension
9499
self.collection_name = collection_name
95100
self.pre_delete_collection = pre_delete_collection
96101
self.collection_metadata = collection_metadata
@@ -113,6 +118,8 @@ def __post_init__(
113118
self.create_vector_functions()
114119

115120
def create_vector_functions(self) -> None:
121+
if self.native_vector:
122+
return
116123
try:
117124
with Session(self._conn) as session:
118125
session.execute(
@@ -188,6 +195,19 @@ def create_vector_functions(self) -> None:
188195

189196
@property
190197
def distance_strategy(self) -> str:
198+
if self.native_vector:
199+
if self._distance_strategy == DistanceStrategy.COSINE:
200+
return self.table.c.embedding.cosine
201+
elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
202+
return self.table.c.embedding.max_inner_product
203+
# elif self._distance_strategy == DistanceStrategy.EUCLIDEAN:
204+
# return "langchain_l2_distance"
205+
else:
206+
raise ValueError(
207+
f"Got unexpected value for distance: {self._distance_strategy}. "
208+
f"Should be one of {', '.join([ds.value for ds in DistanceStrategy])}."
209+
)
210+
191211
if self._distance_strategy == DistanceStrategy.EUCLIDEAN:
192212
return "langchain_l2_distance"
193213
elif self._distance_strategy == DistanceStrategy.COSINE:
@@ -203,6 +223,14 @@ def distance_strategy(self) -> str:
203223
def connect(self) -> Connection:
    """Open a database connection and probe the dialect for vector support.

    Creates an engine from ``self.connection_string`` and
    ``self.engine_args`` and connects. If the connection's dialect reports
    ``supports_vectors``, ``self.native_vector`` is set to ``True`` and
    ``self.native_vector_cosine_similarity`` is copied from the dialect's
    ``vector_cosine_similarity`` attribute.

    The probe is best effort: a dialect that lacks these attributes, or
    fails while reporting them, leaves the flags untouched and the
    connection is still returned.

    Returns:
        The open connection.
    """
    engine = create_engine(self.connection_string, **self.engine_args)
    conn = engine.connect()
    try:
        dialect = conn.dialect
        # getattr defaults cover dialect versions that do not define the
        # capability attributes at all.
        if getattr(dialect, "supports_vectors", False):
            self.native_vector = True
            self.native_vector_cosine_similarity = getattr(
                dialect, "vector_cosine_similarity", False
            )
    except Exception:  # best-effort probe; never fail the connection
        # Catch Exception (not a bare except) so KeyboardInterrupt and
        # SystemExit still propagate, and leave a trace for diagnosis.
        import logging

        logging.getLogger(__name__).debug(
            "Could not detect native vector support", exc_info=True
        )
    return conn
207235

208236
def __del__(self) -> None:
@@ -220,7 +248,14 @@ def table(self) -> Table:
220248
self.collection_name,
221249
Base.metadata,
222250
Column("id", VARCHAR(40), primary_key=True, default=uuid.uuid4),
223-
Column("embedding", IRISListBuild(16000, float)),
251+
Column(
252+
"embedding",
253+
(
254+
IRISVectorType(self.dimension)
255+
if self.native_vector
256+
else IRISListBuild(self.dimension, float)
257+
),
258+
),
224259
Column("document", TEXT, nullable=True),
225260
Column("metadata", TEXT, nullable=True),
226261
extend_existing=True,
@@ -278,6 +313,13 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]:
278313
"Consider providing relevance_score_fn to IRISVector constructor."
279314
)
280315

316+
@staticmethod
def _cosine_relevance_score_fn(distance: float) -> float:
    """Convert a cosine distance into a relevance score.

    Args:
        distance: The distance value returned by the similarity query.

    Returns:
        ``1.0 - distance`` rounded to 15 decimal places, so a distance of
        0 maps to a score of 1.0. The result lies in [0, 1] only when
        ``distance`` itself does.
    """
    # Rounding to 15 places trims floating-point noise in the last digits.
    return round(1.0 - distance, 15)
322+
281323
@classmethod
282324
def from_embeddings(
283325
cls: Type[IRISVector],
@@ -299,8 +341,11 @@ def from_embeddings(
299341
texts = [t[0] for t in text_embeddings]
300342
embeddings = [t[1] for t in text_embeddings]
301343

344+
dimension = len(embeddings[0])
345+
302346
store = cls(
303347
collection_name=collection_name,
348+
dimension=dimension,
304349
distance_strategy=distance_strategy,
305350
embedding_function=embedding,
306351
pre_delete_collection=pre_delete_collection,
@@ -330,8 +375,12 @@ def from_texts(
330375
Return VectorStore initialized from texts and embeddings.
331376
"""
332377

378+
sample_embedding = embedding.embed_query("Hello IRISVector!")
379+
dimension = len(sample_embedding)
380+
333381
store = cls(
334382
collection_name=collection_name,
383+
dimension=dimension,
335384
distance_strategy=distance_strategy,
336385
embedding_function=embedding,
337386
pre_delete_collection=pre_delete_collection,
@@ -483,9 +532,13 @@ def similarity_search_with_score_by_vector(
483532
results: Sequence[Row] = (
484533
session.query(
485534
self.table,
486-
self.table.c.embedding.func(
487-
self.distance_strategy, embedding
488-
).label("distance"),
535+
(
536+
self.distance_strategy(embedding).label("distance")
537+
if self.native_vector
538+
else self.table.c.embedding.func(
539+
self.distance_strategy, embedding
540+
).label("distance")
541+
),
489542
)
490543
.filter(filter_by)
491544
.order_by(asc("distance"))
@@ -499,7 +552,7 @@ def similarity_search_with_score_by_vector(
499552
page_content=result.document,
500553
metadata=json.loads(result.metadata),
501554
),
502-
float(result.distance) if self.embedding_function is not None else None,
555+
round(float(result.distance), 15) if self.embedding_function is not None else None,
503556
)
504557
for result in results
505558
]

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
),
88
install_requires=[
99
"langchain==0.0.348",
10-
"sqlalchemy-iris>=0.12.0",
10+
"sqlalchemy-iris>=0.13.0",
1111
],
1212
python_requires=">3.7,<3.12",
1313
)

tests/test_vectorstores.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from langchain.docstore.document import Document
99
from sqlalchemy.orm import Session
1010

11-
from langchain_iris import IRISVector, DistanceStrategy
11+
from langchain_iris import IRISVector
1212
from langchain.embeddings.fake import DeterministicFakeEmbedding
1313

1414

@@ -138,7 +138,7 @@ def test_irisvector_with_filter_distant_match(
138138
)
139139
output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "2"})
140140
assert output == [
141-
(Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406)
141+
(Document(page_content="baz", metadata={"page": "2"}), 0.001300390667138)
142142
]
143143

144144

@@ -232,13 +232,10 @@ def test_irisvector_relevance_score(collection_name, connection_string) -> None:
232232
)
233233

234234
output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
235-
print([[doc.page_content, score] for doc, score in output])
236-
# [1.0, 0.07647878038412692, -0.10694835587301355]
237-
# [1.0, 0.07647878038412692, -0.10694835587301355]
238235
assert output == [
239236
(Document(page_content="foo", metadata={"page": "0"}), 1.0),
240-
(Document(page_content="bar", metadata={"page": "1"}), 0.9996744261675065),
241-
(Document(page_content="baz", metadata={"page": "2"}), 0.9986996093328621),
237+
(Document(page_content="bar", metadata={"page": "1"}), 0.999674426167506),
238+
(Document(page_content="baz", metadata={"page": "2"}), 0.998699609332862),
242239
]
243240

244241

0 commit comments

Comments
 (0)