22
33import uuid
44from collections .abc import AsyncGenerator
5+ from contextlib import asynccontextmanager
56from typing import Annotated , Any
67
78import pandas as pd
1011from pydantic import BaseModel
1112
1213from semantic_kernel .connectors .memory .postgres import PostgresStore
13- from semantic_kernel .connectors .memory .postgres .postgres_collection import PostgresCollection
1414from semantic_kernel .connectors .memory .postgres .postgres_settings import PostgresSettings
1515from semantic_kernel .data .const import DistanceFunction , IndexKind
1616from semantic_kernel .data .vector_store_model_decorator import vectorstoremodel
1717from semantic_kernel .data .vector_store_model_definition import VectorStoreRecordDefinition
18- from semantic_kernel .data .vector_store_record_collection import VectorStoreRecordCollection
1918from semantic_kernel .data .vector_store_record_fields import (
2019 VectorStoreRecordDataField ,
2120 VectorStoreRecordKeyField ,
@@ -85,14 +84,22 @@ async def vector_store() -> AsyncGenerator[PostgresStore, None]:
8584 yield PostgresStore (connection_pool = pool )
8685
8786
88- @pytest_asyncio .fixture (scope = "function" )
89- async def simple_collection (vector_store : PostgresStore ):
87+ @asynccontextmanager
88+ async def create_simple_collection (vector_store : PostgresStore ):
89+ """Returns a collection with a unique name that is deleted after the context.
90+
91+ This can be moved to use a fixture with scope=function and loop_scope=session
92+ after upgrade to pytest-asyncio 0.24. With the current version, the fixture
93+ would both cache and use the event loop of the declared scope.
94+ """
9095 suffix = str (uuid .uuid4 ()).replace ("-" , "" )[:8 ]
9196 collection_id = f"test_collection_{ suffix } "
9297 collection = vector_store .get_collection (collection_id , SimpleDataModel )
9398 await collection .create_collection ()
94- yield collection
95- await collection .delete_collection ()
99+ try :
100+ yield collection
101+ finally :
102+ await collection .delete_collection ()
96103
97104
98105def test_create_store (vector_store ):
@@ -118,37 +125,40 @@ async def test_create_does_collection_exist_and_delete(vector_store: PostgresSto
118125
119126
120127@pytest .mark .asyncio (scope = "session" )
121- async def test_list_collection_names (vector_store , simple_collection ):
122- simple_collection_id = simple_collection .collection_name
123- result = await vector_store .list_collection_names ()
124- assert simple_collection_id in result
128+ async def test_list_collection_names (vector_store ):
129+ async with create_simple_collection (vector_store ) as simple_collection :
130+ simple_collection_id = simple_collection .collection_name
131+ result = await vector_store .list_collection_names ()
132+ assert simple_collection_id in result
125133
126134
127135@pytest .mark .asyncio (scope = "session" )
128- async def test_upsert_get_and_delete (simple_collection : PostgresCollection ):
136+ async def test_upsert_get_and_delete (vector_store : PostgresStore ):
129137 record = SimpleDataModel (id = 1 , embedding = [1.1 , 2.2 , 3.3 ], data = {"key" : "value" })
138+ async with create_simple_collection (vector_store ) as simple_collection :
139+ result_before_upsert = await simple_collection .get (1 )
140+ assert result_before_upsert is None
130141
131- result_before_upsert = await simple_collection .get (1 )
132- assert result_before_upsert is None
133-
134- await simple_collection .upsert (record )
135- result = await simple_collection .get (1 )
136- assert result is not None
137- assert result .id == record .id
138- assert result .embedding == record .embedding
139- assert result .data == record .data
140-
141- # Check that the table has an index
142- connection_pool = simple_collection .connection_pool
143- async with connection_pool .connection () as conn , conn .cursor () as cur :
144- await cur .execute ("SELECT indexname FROM pg_indexes WHERE tablename = %s" , (simple_collection .collection_name ,))
145- rows = await cur .fetchall ()
146- index_names = [index [0 ] for index in rows ]
147- assert any ("embedding_idx" in index_name for index_name in index_names )
148-
149- await simple_collection .delete (1 )
150- result_after_delete = await simple_collection .get (1 )
151- assert result_after_delete is None
142+ await simple_collection .upsert (record )
143+ result = await simple_collection .get (1 )
144+ assert result is not None
145+ assert result .id == record .id
146+ assert result .embedding == record .embedding
147+ assert result .data == record .data
148+
149+ # Check that the table has an index
150+ connection_pool = simple_collection .connection_pool
151+ async with connection_pool .connection () as conn , conn .cursor () as cur :
152+ await cur .execute (
153+ "SELECT indexname FROM pg_indexes WHERE tablename = %s" , (simple_collection .collection_name ,)
154+ )
155+ rows = await cur .fetchall ()
156+ index_names = [index [0 ] for index in rows ]
157+ assert any ("embedding_idx" in index_name for index_name in index_names )
158+
159+ await simple_collection .delete (1 )
160+ result_after_delete = await simple_collection .get (1 )
161+ assert result_after_delete is None
152162
153163
154164@pytest .mark .asyncio (scope = "session" )
@@ -182,28 +192,29 @@ async def test_upsert_get_and_delete_pandas(vector_store):
182192
183193
184194@pytest .mark .asyncio (scope = "session" )
185- async def test_upsert_get_and_delete_batch (simple_collection : VectorStoreRecordCollection ):
186- record1 = SimpleDataModel (id = 1 , embedding = [1.1 , 2.2 , 3.3 ], data = {"key" : "value" })
187- record2 = SimpleDataModel (id = 2 , embedding = [4.4 , 5.5 , 6.6 ], data = {"key" : "value" })
188-
189- result_before_upsert = await simple_collection .get_batch ([1 , 2 ])
190- assert result_before_upsert is None
191-
192- await simple_collection .upsert_batch ([record1 , record2 ])
193- # Test get_batch for the two existing keys and one non-existing key;
194- # this should return only the two existing records.
195- result = await simple_collection .get_batch ([1 , 2 , 3 ])
196- assert result is not None
197- assert len (result ) == 2
198- assert result [0 ] is not None
199- assert result [0 ].id == record1 .id
200- assert result [0 ].embedding == record1 .embedding
201- assert result [0 ].data == record1 .data
202- assert result [1 ] is not None
203- assert result [1 ].id == record2 .id
204- assert result [1 ].embedding == record2 .embedding
205- assert result [1 ].data == record2 .data
206-
207- await simple_collection .delete_batch ([1 , 2 ])
208- result_after_delete = await simple_collection .get_batch ([1 , 2 ])
209- assert result_after_delete is None
195+ async def test_upsert_get_and_delete_batch (vector_store : PostgresStore ):
196+ async with create_simple_collection (vector_store ) as simple_collection :
197+ record1 = SimpleDataModel (id = 1 , embedding = [1.1 , 2.2 , 3.3 ], data = {"key" : "value" })
198+ record2 = SimpleDataModel (id = 2 , embedding = [4.4 , 5.5 , 6.6 ], data = {"key" : "value" })
199+
200+ result_before_upsert = await simple_collection .get_batch ([1 , 2 ])
201+ assert result_before_upsert is None
202+
203+ await simple_collection .upsert_batch ([record1 , record2 ])
204+ # Test get_batch for the two existing keys and one non-existing key;
205+ # this should return only the two existing records.
206+ result = await simple_collection .get_batch ([1 , 2 , 3 ])
207+ assert result is not None
208+ assert len (result ) == 2
209+ assert result [0 ] is not None
210+ assert result [0 ].id == record1 .id
211+ assert result [0 ].embedding == record1 .embedding
212+ assert result [0 ].data == record1 .data
213+ assert result [1 ] is not None
214+ assert result [1 ].id == record2 .id
215+ assert result [1 ].embedding == record2 .embedding
216+ assert result [1 ].data == record2 .data
217+
218+ await simple_collection .delete_batch ([1 , 2 ])
219+ result_after_delete = await simple_collection .get_batch ([1 , 2 ])
220+ assert result_after_delete is None
0 commit comments