@@ -46,166 +46,165 @@ def connection_string():
4646
4747
4848def test_constructor (connection_string ):
49- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
50- assert memory ._connection_pool is not None
49+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
50+ assert memory ._connection_pool is not None
5151
5252
5353@pytest .mark .asyncio
5454async def test_create_and_does_collection_exist (connection_string ):
55- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
56- await memory .create_collection ("test_collection" )
57- result = await memory .does_collection_exist ("test_collection" )
58- assert result is not None
55+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
56+ await memory .create_collection ("test_collection" )
57+ result = await memory .does_collection_exist ("test_collection" )
58+ assert result is not None
5959
6060
6161@pytest .mark .asyncio
6262async def test_get_collections (connection_string ):
63- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
64-
65- try :
66- await memory .create_collection ("test_collection" )
67- result = await memory .get_collections ()
68- assert "test_collection" in result
69- except PoolTimeout :
70- pytest .skip ("PoolTimeout exception raised, skipping test." )
63+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
64+ try :
65+ await memory .create_collection ("test_collection" )
66+ result = await memory .get_collections ()
67+ assert "test_collection" in result
68+ except PoolTimeout :
69+ pytest .skip ("PoolTimeout exception raised, skipping test." )
7170
7271
7372@pytest .mark .asyncio
7473async def test_delete_collection (connection_string ):
75- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
76- try :
77- await memory .create_collection ("test_collection" )
74+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
75+ try :
76+ await memory .create_collection ("test_collection" )
7877
79- result = await memory .get_collections ()
80- assert "test_collection" in result
78+ result = await memory .get_collections ()
79+ assert "test_collection" in result
8180
82- await memory .delete_collection ("test_collection" )
83- result = await memory .get_collections ()
84- assert "test_collection" not in result
85- except PoolTimeout :
86- pytest .skip ("PoolTimeout exception raised, skipping test." )
81+ await memory .delete_collection ("test_collection" )
82+ result = await memory .get_collections ()
83+ assert "test_collection" not in result
84+ except PoolTimeout :
85+ pytest .skip ("PoolTimeout exception raised, skipping test." )
8786
8887
8988@pytest .mark .asyncio
9089async def test_does_collection_exist (connection_string ):
91- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
92- try :
93- await memory .create_collection ("test_collection" )
94- result = await memory .does_collection_exist ("test_collection" )
95- assert result is True
96- except PoolTimeout :
97- pytest .skip ("PoolTimeout exception raised, skipping test." )
90+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
91+ try :
92+ await memory .create_collection ("test_collection" )
93+ result = await memory .does_collection_exist ("test_collection" )
94+ assert result is True
95+ except PoolTimeout :
96+ pytest .skip ("PoolTimeout exception raised, skipping test." )
9897
9998
10099@pytest .mark .asyncio
101100async def test_upsert_and_get (connection_string , memory_record1 ):
102- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
103- try :
104- await memory .create_collection ("test_collection" )
105- await memory .upsert ("test_collection" , memory_record1 )
106- result = await memory .get ("test_collection" , memory_record1 ._id , with_embedding = True )
107- assert result is not None
108- assert result ._id == memory_record1 ._id
109- assert result ._text == memory_record1 ._text
110- assert result ._timestamp == memory_record1 ._timestamp
111- for i in range (len (result ._embedding )):
112- assert result ._embedding [i ] == memory_record1 ._embedding [i ]
113- except PoolTimeout :
114- pytest .skip ("PoolTimeout exception raised, skipping test." )
101+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
102+ try :
103+ await memory .create_collection ("test_collection" )
104+ await memory .upsert ("test_collection" , memory_record1 )
105+ result = await memory .get ("test_collection" , memory_record1 ._id , with_embedding = True )
106+ assert result is not None
107+ assert result ._id == memory_record1 ._id
108+ assert result ._text == memory_record1 ._text
109+ assert result ._timestamp == memory_record1 ._timestamp
110+ for i in range (len (result ._embedding )):
111+ assert result ._embedding [i ] == memory_record1 ._embedding [i ]
112+ except PoolTimeout :
113+ pytest .skip ("PoolTimeout exception raised, skipping test." )
115114
116115
117116@pytest .mark .asyncio
118117async def test_upsert_batch_and_get_batch (connection_string , memory_record1 , memory_record2 ):
119- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
120- try :
121- await memory .create_collection ("test_collection" )
122- await memory .upsert_batch ("test_collection" , [memory_record1 , memory_record2 ])
123-
124- results = await memory .get_batch (
125- "test_collection" ,
126- [memory_record1 ._id , memory_record2 ._id ],
127- with_embeddings = True ,
128- )
129- assert len (results ) == 2
130- assert results [0 ]._id in [memory_record1 ._id , memory_record2 ._id ]
131- assert results [1 ]._id in [memory_record1 ._id , memory_record2 ._id ]
132- except PoolTimeout :
133- pytest .skip ("PoolTimeout exception raised, skipping test." )
118+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
119+ try :
120+ await memory .create_collection ("test_collection" )
121+ await memory .upsert_batch ("test_collection" , [memory_record1 , memory_record2 ])
122+
123+ results = await memory .get_batch (
124+ "test_collection" ,
125+ [memory_record1 ._id , memory_record2 ._id ],
126+ with_embeddings = True ,
127+ )
128+ assert len (results ) == 2
129+ assert results [0 ]._id in [memory_record1 ._id , memory_record2 ._id ]
130+ assert results [1 ]._id in [memory_record1 ._id , memory_record2 ._id ]
131+ except PoolTimeout :
132+ pytest .skip ("PoolTimeout exception raised, skipping test." )
134133
135134
136135@pytest .mark .asyncio
137136async def test_remove (connection_string , memory_record1 ):
138- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
139- try :
140- await memory .create_collection ("test_collection" )
141- await memory .upsert ("test_collection" , memory_record1 )
137+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
138+ try :
139+ await memory .create_collection ("test_collection" )
140+ await memory .upsert ("test_collection" , memory_record1 )
142141
143- result = await memory .get ("test_collection" , memory_record1 ._id , with_embedding = True )
144- assert result is not None
142+ result = await memory .get ("test_collection" , memory_record1 ._id , with_embedding = True )
143+ assert result is not None
145144
146- await memory .remove ("test_collection" , memory_record1 ._id )
147- with pytest .raises (ServiceResourceNotFoundError ):
148- await memory .get ("test_collection" , memory_record1 ._id , with_embedding = True )
149- except PoolTimeout :
150- pytest .skip ("PoolTimeout exception raised, skipping test." )
145+ await memory .remove ("test_collection" , memory_record1 ._id )
146+ with pytest .raises (ServiceResourceNotFoundError ):
147+ await memory .get ("test_collection" , memory_record1 ._id , with_embedding = True )
148+ except PoolTimeout :
149+ pytest .skip ("PoolTimeout exception raised, skipping test." )
151150
152151
153152@pytest .mark .asyncio
154153async def test_remove_batch (connection_string , memory_record1 , memory_record2 ):
155- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
156- try :
157- await memory .create_collection ("test_collection" )
158- await memory .upsert_batch ("test_collection" , [memory_record1 , memory_record2 ])
159- await memory .remove_batch ("test_collection" , [memory_record1 ._id , memory_record2 ._id ])
160- with pytest .raises (ServiceResourceNotFoundError ):
161- _ = await memory .get ("test_collection" , memory_record1 ._id , with_embedding = True )
154+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
155+ try :
156+ await memory .create_collection ("test_collection" )
157+ await memory .upsert_batch ("test_collection" , [memory_record1 , memory_record2 ])
158+ await memory .remove_batch ("test_collection" , [memory_record1 ._id , memory_record2 ._id ])
159+ with pytest .raises (ServiceResourceNotFoundError ):
160+ _ = await memory .get ("test_collection" , memory_record1 ._id , with_embedding = True )
162161
163- with pytest .raises (ServiceResourceNotFoundError ):
164- _ = await memory .get ("test_collection" , memory_record2 ._id , with_embedding = True )
165- except PoolTimeout :
166- pytest .skip ("PoolTimeout exception raised, skipping test." )
162+ with pytest .raises (ServiceResourceNotFoundError ):
163+ _ = await memory .get ("test_collection" , memory_record2 ._id , with_embedding = True )
164+ except PoolTimeout :
165+ pytest .skip ("PoolTimeout exception raised, skipping test." )
167166
168167
169168@pytest .mark .asyncio
170169async def test_get_nearest_match (connection_string , memory_record1 , memory_record2 ):
171- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
172- try :
173- await memory .create_collection ("test_collection" )
174- await memory .upsert_batch ("test_collection" , [memory_record1 , memory_record2 ])
175- test_embedding = memory_record1 .embedding .copy ()
176- test_embedding [0 ] = test_embedding [0 ] + 0.01
177-
178- result = await memory .get_nearest_match (
179- "test_collection" , test_embedding , min_relevance_score = 0.0 , with_embedding = True
180- )
181- assert result is not None
182- assert result [0 ]._id == memory_record1 ._id
183- assert result [0 ]._text == memory_record1 ._text
184- assert result [0 ]._timestamp == memory_record1 ._timestamp
185- for i in range (len (result [0 ]._embedding )):
186- assert result [0 ]._embedding [i ] == memory_record1 ._embedding [i ]
187- except PoolTimeout :
188- pytest .skip ("PoolTimeout exception raised, skipping test." )
170+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
171+ try :
172+ await memory .create_collection ("test_collection" )
173+ await memory .upsert_batch ("test_collection" , [memory_record1 , memory_record2 ])
174+ test_embedding = memory_record1 .embedding .copy ()
175+ test_embedding [0 ] = test_embedding [0 ] + 0.01
176+
177+ result = await memory .get_nearest_match (
178+ "test_collection" , test_embedding , min_relevance_score = 0.0 , with_embedding = True
179+ )
180+ assert result is not None
181+ assert result [0 ]._id == memory_record1 ._id
182+ assert result [0 ]._text == memory_record1 ._text
183+ assert result [0 ]._timestamp == memory_record1 ._timestamp
184+ for i in range (len (result [0 ]._embedding )):
185+ assert result [0 ]._embedding [i ] == memory_record1 ._embedding [i ]
186+ except PoolTimeout :
187+ pytest .skip ("PoolTimeout exception raised, skipping test." )
189188
190189
191190@pytest .mark .asyncio
192191async def test_get_nearest_matches (connection_string , memory_record1 , memory_record2 , memory_record3 ):
193- memory = PostgresMemoryStore (connection_string , 2 , 1 , 5 )
194- try :
195- await memory .create_collection ("test_collection" )
196- await memory .upsert_batch ("test_collection" , [memory_record1 , memory_record2 , memory_record3 ])
197- test_embedding = memory_record2 .embedding
198- test_embedding [0 ] = test_embedding [0 ] + 0.025
199-
200- result = await memory .get_nearest_matches (
201- "test_collection" ,
202- test_embedding ,
203- limit = 2 ,
204- min_relevance_score = 0.0 ,
205- with_embeddings = True ,
206- )
207- assert len (result ) == 2
208- assert result [0 ][0 ]._id in [memory_record3 ._id , memory_record2 ._id ]
209- assert result [1 ][0 ]._id in [memory_record3 ._id , memory_record2 ._id ]
210- except PoolTimeout :
211- pytest .skip ("PoolTimeout exception raised, skipping test." )
192+ with PostgresMemoryStore (connection_string , 2 , 1 , 5 ) as memory :
193+ try :
194+ await memory .create_collection ("test_collection" )
195+ await memory .upsert_batch ("test_collection" , [memory_record1 , memory_record2 , memory_record3 ])
196+ test_embedding = memory_record2 .embedding
197+ test_embedding [0 ] = test_embedding [0 ] + 0.025
198+
199+ result = await memory .get_nearest_matches (
200+ "test_collection" ,
201+ test_embedding ,
202+ limit = 2 ,
203+ min_relevance_score = 0.0 ,
204+ with_embeddings = True ,
205+ )
206+ assert len (result ) == 2
207+ assert result [0 ][0 ]._id in [memory_record3 ._id , memory_record2 ._id ]
208+ assert result [1 ][0 ]._id in [memory_record3 ._id , memory_record2 ._id ]
209+ except PoolTimeout :
210+ pytest .skip ("PoolTimeout exception raised, skipping test." )
0 commit comments