1313# limitations under the License.
1414
1515import os
16+ import pytest
1617
1718from google .cloud import datastore
1819from google .cloud .datastore .vector import Vector
2829PROJECT_ID = os .environ ["GOOGLE_CLOUD_PROJECT" ]
2930
3031
31- def test_store_vectors ():
32- client = store_vectors ()
33-
34- results = client .query ("coffee-beans" , limit = 5 ).fetch ()
32+ @pytest .fixture (scope = "module" )
33+ def db ():
34+ client = datastore .Client ()
35+ _clear_db (client )
36+ entity_list = add_coffee_beans_data (client )
37+ yield client
38+ for e in entity_list :
39+ client .delete (e )
3540
36- assert len (list (results )) == 1
41+ def _clear_db (db ):
42+ """remove all entities with kind-coffee-beans, so we have a new databse"""
43+ query = db .query (kind = "coffee-beans" )
44+ query .keys_only ()
45+ keys = list (query .fetch ())
46+ db .delete_multi (keys )
3747
3848
3949def add_coffee_beans_data (db ):
@@ -46,60 +56,54 @@ def add_coffee_beans_data(db):
4656 entity4 = datastore .Entity (db .key ("coffee-beans" , "Liberica" ))
4757 entity4 .update ({"embedding_field" : Vector ([3.0 , 1.0 , 2.0 ]), "color" : "green" })
4858
49- db .put_multi ([entity1 , entity2 , entity3 , entity4 ])
59+ entity_list = [entity1 , entity2 , entity3 , entity4 ]
60+ db .put_multi (entity_list )
61+ return entity_list
5062
63+ def test_store_vectors ():
64+ # run an ensure there are no exceptions
65+ client , entity = store_vectors ()
66+ client .delete (entity )
5167
52- def test_vector_search_basic ():
53- db = datastore .Client ()
54- add_coffee_beans_data (db )
55-
68+ def test_vector_search_basic (db ):
5669 vector_query = vector_search_basic (db )
5770 results = list (vector_query .fetch ())
5871
5972 assert len (results ) == 4
60- assert results [0 ].name == "Liberica"
61- assert results [1 ].name == "Robusta"
62- assert results [2 ].name == "Arabica"
63- assert results [3 ].name == "Excelsa"
73+ assert results [0 ].key . name == "Liberica"
74+ assert results [1 ].key . name == "Robusta"
75+ assert results [2 ].key . name == "Arabica"
76+ assert results [3 ].key . name == "Excelsa"
6477
6578
66- def test_vector_search_prefilter ():
67- db = datastore .Client ()
68- add_coffee_beans_data (db )
69-
79+ def test_vector_search_prefilter (db ):
7080 vector_query = vector_search_prefilter (db )
7181 results = list (vector_query .fetch ())
7282
7383 assert len (results ) == 2
74- assert results [0 ].name == "Arabica"
75- assert results [1 ].name == "Excelsa"
76-
84+ assert results [0 ].key .name == "Arabica"
85+ assert results [1 ].key .name == "Excelsa"
7786
78- def test_vector_search_distance_result_field ():
79- db = datastore .Client ()
80- add_coffee_beans_data (db )
8187
88+ def test_vector_search_distance_result_field (db ):
8289 vector_query = vector_search_distance_result_field (db )
8390 results = list (vector_query .fetch ())
8491
8592 assert len (results ) == 4
86- assert results [0 ].name == "Liberica"
93+ assert results [0 ].key . name == "Liberica"
8794 assert results [0 ]["vector_distance" ] == 0.0
88- assert results [1 ].name == "Robusta"
95+ assert results [1 ].key . name == "Robusta"
8996 assert results [1 ]["vector_distance" ] == 1.0
90- assert results [2 ].name == "Arabica"
97+ assert results [2 ].key . name == "Arabica"
9198 assert results [2 ]["vector_distance" ] == 7.0
92- assert results [3 ].name == "Excelsa"
99+ assert results [3 ].key . name == "Excelsa"
93100 assert results [3 ]["vector_distance" ] == 8.0
94101
95102
96- def test_vector_search_distance_threshold ():
97- db = datastore .Client ()
98- add_coffee_beans_data (db )
99-
103+ def test_vector_search_distance_threshold (db ):
100104 vector_query = vector_search_distance_threshold (db )
101105 results = list (vector_query .fetch ())
102106
103107 assert len (results ) == 2
104- assert results [0 ].name == "Liberica"
105- assert results [1 ].name == "Robusta"
108+ assert results [0 ].key . name == "Liberica"
109+ assert results [1 ].key . name == "Robusta"
0 commit comments