@@ -46,6 +46,15 @@ class Item(Base):
4646)
4747index .create (engine )
4848
49+ half_precision_index = Index (
50+ 'sqlalchemy_orm_half_precision_index' ,
51+ func .cast (Item .embedding , HALFVEC (3 )).label ('embedding' ),
52+ postgresql_using = 'hnsw' ,
53+ postgresql_with = {'m' : 16 , 'ef_construction' : 64 },
54+ postgresql_ops = {'embedding' : 'halfvec_l2_ops' }
55+ )
56+ half_precision_index .create (engine )
57+
4958
5059def create_items ():
5160 session = Session (engine )
@@ -438,6 +447,12 @@ def test_vector_array(self):
438447 assert item .embeddings [0 ].tolist () == [1 , 2 , 3 ]
439448 assert item .embeddings [1 ].tolist () == [4 , 5 , 6 ]
440449
450+ def test_half_precision (self ):
451+ create_items ()
452+ with Session (engine ) as session :
453+ items = session .query (Item ).order_by (func .cast (Item .embedding , HALFVEC (3 )).l2_distance ([1 , 1 , 1 ])).all ()
454+ assert [v .id for v in items ] == [1 , 3 , 2 ]
455+
441456 @pytest .mark .asyncio
442457 @pytest .mark .skipif (sqlalchemy_version == 1 , reason = 'Requires SQLAlchemy 2+' )
443458 async def test_async (self ):
0 commit comments