Skip to content

Commit 1833f1e

Browse files
committed
fixed sample tests
1 parent 9d94b67 commit 1833f1e

File tree

2 files changed

+39
-35
lines changed

2 files changed

+39
-35
lines changed

datastore/cloud-client/vector_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def store_vectors():
3030

3131
client.put(entity)
3232
# [END datastore_store_vectors]
33-
return client
33+
return client, entity
3434

3535

3636
def vector_search_basic(db):

datastore/cloud-client/vector_search_test.py

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import os
16+
import pytest
1617

1718
from google.cloud import datastore
1819
from google.cloud.datastore.vector import Vector
@@ -28,12 +29,21 @@
2829
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]
2930

3031

31-
def test_store_vectors():
32-
client = store_vectors()
33-
34-
results = client.query("coffee-beans", limit=5).fetch()
32+
@pytest.fixture(scope="module")
33+
def db():
34+
client = datastore.Client()
35+
_clear_db(client)
36+
entity_list = add_coffee_beans_data(client)
37+
yield client
38+
for e in entity_list:
39+
client.delete(e)
3540

36-
assert len(list(results)) == 1
41+
def _clear_db(db):
42+
"""remove all entities with kind-coffee-beans, so we have a new databse"""
43+
query = db.query(kind="coffee-beans")
44+
query.keys_only()
45+
keys = list(query.fetch())
46+
db.delete_multi(keys)
3747

3848

3949
def add_coffee_beans_data(db):
@@ -46,60 +56,54 @@ def add_coffee_beans_data(db):
4656
entity4 = datastore.Entity(db.key("coffee-beans", "Liberica"))
4757
entity4.update({"embedding_field": Vector([3.0, 1.0, 2.0]), "color": "green"})
4858

49-
db.put_multi([entity1, entity2, entity3, entity4])
59+
entity_list = [entity1, entity2, entity3, entity4]
60+
db.put_multi(entity_list)
61+
return entity_list
5062

63+
def test_store_vectors():
64+
# run an ensure there are no exceptions
65+
client, entity = store_vectors()
66+
client.delete(entity)
5167

52-
def test_vector_search_basic():
53-
db = datastore.Client()
54-
add_coffee_beans_data(db)
55-
68+
def test_vector_search_basic(db):
5669
vector_query = vector_search_basic(db)
5770
results = list(vector_query.fetch())
5871

5972
assert len(results) == 4
60-
assert results[0].name == "Liberica"
61-
assert results[1].name == "Robusta"
62-
assert results[2].name == "Arabica"
63-
assert results[3].name == "Excelsa"
73+
assert results[0].key.name == "Liberica"
74+
assert results[1].key.name == "Robusta"
75+
assert results[2].key.name == "Arabica"
76+
assert results[3].key.name == "Excelsa"
6477

6578

66-
def test_vector_search_prefilter():
67-
db = datastore.Client()
68-
add_coffee_beans_data(db)
69-
79+
def test_vector_search_prefilter(db):
7080
vector_query = vector_search_prefilter(db)
7181
results = list(vector_query.fetch())
7282

7383
assert len(results) == 2
74-
assert results[0].name == "Arabica"
75-
assert results[1].name == "Excelsa"
76-
84+
assert results[0].key.name == "Arabica"
85+
assert results[1].key.name == "Excelsa"
7786

78-
def test_vector_search_distance_result_field():
79-
db = datastore.Client()
80-
add_coffee_beans_data(db)
8187

88+
def test_vector_search_distance_result_field(db):
8289
vector_query = vector_search_distance_result_field(db)
8390
results = list(vector_query.fetch())
8491

8592
assert len(results) == 4
86-
assert results[0].name == "Liberica"
93+
assert results[0].key.name == "Liberica"
8794
assert results[0]["vector_distance"] == 0.0
88-
assert results[1].name == "Robusta"
95+
assert results[1].key.name == "Robusta"
8996
assert results[1]["vector_distance"] == 1.0
90-
assert results[2].name == "Arabica"
97+
assert results[2].key.name == "Arabica"
9198
assert results[2]["vector_distance"] == 7.0
92-
assert results[3].name == "Excelsa"
99+
assert results[3].key.name == "Excelsa"
93100
assert results[3]["vector_distance"] == 8.0
94101

95102

96-
def test_vector_search_distance_threshold():
97-
db = datastore.Client()
98-
add_coffee_beans_data(db)
99-
103+
def test_vector_search_distance_threshold(db):
100104
vector_query = vector_search_distance_threshold(db)
101105
results = list(vector_query.fetch())
102106

103107
assert len(results) == 2
104-
assert results[0].name == "Liberica"
105-
assert results[1].name == "Robusta"
108+
assert results[0].key.name == "Liberica"
109+
assert results[1].key.name == "Robusta"

0 commit comments

Comments
 (0)