4040 func ,
4141)
4242from sqlalchemy_iris import IRISListBuild
43+ from sqlalchemy_iris import IRISVector as IRISVectorType
4344
4445from sqlalchemy .orm import Session
4546
@@ -75,10 +76,13 @@ class DistanceStrategy(str, enum.Enum):
7576
7677class IRISVector (VectorStore ):
7778 _conn = None
79+ native_vector = False
80+ native_vector_cosine_similarity = False
7881
7982 def __init__ (
8083 self ,
8184 embedding_function : Embeddings ,
85+ dimension : int ,
8286 connection_string : Optional [str ] = None ,
8387 collection_name : str = _LANGCHAIN_DEFAULT_COLLECTION_NAME ,
8488 pre_delete_collection : bool = False ,
@@ -91,6 +95,7 @@ def __init__(
9195 ) -> None :
9296 self .connection_string = connection_string or "iris+emb:///"
9397 self .embedding_function = embedding_function
98+ self .dimension = dimension
9499 self .collection_name = collection_name
95100 self .pre_delete_collection = pre_delete_collection
96101 self .collection_metadata = collection_metadata
@@ -113,6 +118,8 @@ def __post_init__(
113118 self .create_vector_functions ()
114119
115120 def create_vector_functions (self ) -> None :
121+ if self .native_vector :
122+ return
116123 try :
117124 with Session (self ._conn ) as session :
118125 session .execute (
@@ -188,6 +195,19 @@ def create_vector_functions(self) -> None:
188195
189196 @property
190197 def distance_strategy (self ) -> str :
198+ if self .native_vector :
199+ if self ._distance_strategy == DistanceStrategy .COSINE :
200+ return self .table .c .embedding .cosine
201+ elif self ._distance_strategy == DistanceStrategy .MAX_INNER_PRODUCT :
202+ return self .table .c .embedding .max_inner_product
203+ # elif self._distance_strategy == DistanceStrategy.EUCLIDEAN:
204+ # return "langchain_l2_distance"
205+ else :
206+ raise ValueError (
207+ f"Got unexpected value for distance: { self ._distance_strategy } . "
208+ f"Should be one of { ', ' .join ([ds .value for ds in DistanceStrategy ])} ."
209+ )
210+
191211 if self ._distance_strategy == DistanceStrategy .EUCLIDEAN :
192212 return "langchain_l2_distance"
193213 elif self ._distance_strategy == DistanceStrategy .COSINE :
@@ -203,6 +223,14 @@ def distance_strategy(self) -> str:
203223 def connect (self ) -> Connection :
204224 engine = create_engine (self .connection_string , ** self .engine_args )
205225 conn = engine .connect ()
226+ try :
227+ if conn .dialect .supports_vectors :
228+ self .native_vector = True
229+ self .native_vector_cosine_similarity = (
230+ conn .dialect .vector_cosine_similarity
231+ )
232+ except : # noqa
233+ pass
206234 return conn
207235
208236 def __del__ (self ) -> None :
@@ -220,7 +248,14 @@ def table(self) -> Table:
220248 self .collection_name ,
221249 Base .metadata ,
222250 Column ("id" , VARCHAR (40 ), primary_key = True , default = uuid .uuid4 ),
223- Column ("embedding" , IRISListBuild (16000 , float )),
251+ Column (
252+ "embedding" ,
253+ (
254+ IRISVectorType (self .dimension )
255+ if self .native_vector
256+ else IRISListBuild (self .dimension , float )
257+ ),
258+ ),
224259 Column ("document" , TEXT , nullable = True ),
225260 Column ("metadata" , TEXT , nullable = True ),
226261 extend_existing = True ,
@@ -278,6 +313,13 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]:
278313 "Consider providing relevance_score_fn to IRISVector constructor."
279314 )
280315
316+ @staticmethod
317+ def _cosine_relevance_score_fn (distance : float ) -> float :
318+ print ('_cosine_relevance_score_fn' , distance )
319+ """Normalize the distance to a score on a scale [0, 1]."""
320+
321+ return round (1.0 - distance , 15 )
322+
281323 @classmethod
282324 def from_embeddings (
283325 cls : Type [IRISVector ],
@@ -299,8 +341,11 @@ def from_embeddings(
299341 texts = [t [0 ] for t in text_embeddings ]
300342 embeddings = [t [1 ] for t in text_embeddings ]
301343
344+ dimension = len (embeddings [0 ])
345+
302346 store = cls (
303347 collection_name = collection_name ,
348+ dimension = dimension ,
304349 distance_strategy = distance_strategy ,
305350 embedding_function = embedding ,
306351 pre_delete_collection = pre_delete_collection ,
@@ -330,8 +375,12 @@ def from_texts(
330375 Return VectorStore initialized from texts and embeddings.
331376 """
332377
378+ sample_embedding = embedding .embed_query ("Hello IRISVector!" )
379+ dimension = len (sample_embedding )
380+
333381 store = cls (
334382 collection_name = collection_name ,
383+ dimension = dimension ,
335384 distance_strategy = distance_strategy ,
336385 embedding_function = embedding ,
337386 pre_delete_collection = pre_delete_collection ,
@@ -483,9 +532,13 @@ def similarity_search_with_score_by_vector(
483532 results : Sequence [Row ] = (
484533 session .query (
485534 self .table ,
486- self .table .c .embedding .func (
487- self .distance_strategy , embedding
488- ).label ("distance" ),
535+ (
536+ self .distance_strategy (embedding ).label ("distance" )
537+ if self .native_vector
538+ else self .table .c .embedding .func (
539+ self .distance_strategy , embedding
540+ ).label ("distance" )
541+ ),
489542 )
490543 .filter (filter_by )
491544 .order_by (asc ("distance" ))
@@ -499,7 +552,7 @@ def similarity_search_with_score_by_vector(
499552 page_content = result .document ,
500553 metadata = json .loads (result .metadata ),
501554 ),
502- float (result .distance ) if self .embedding_function is not None else None ,
555+ round ( float (result .distance ), 15 ) if self .embedding_function is not None else None ,
503556 )
504557 for result in results
505558 ]
0 commit comments