Skip to content

Commit c9319c9

Browse files
committed
added projection sample
1 parent 500613e commit c9319c9

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

datastore/cloud-client/vector_search.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,29 @@ def vector_search_distance_result_property(db):
9595
# [END datastore_vector_search_distance_result_property]
9696
return vector_query
9797

98+
def vector_search_distance_result_property_projection(db):
99+
# [START datastore_vector_search_distance_result_property_projection]
100+
from google.cloud.datastore.vector import DistanceMeasure
101+
from google.cloud.datastore.vector import Vector
102+
from google.cloud.datastore.vector import FindNearest
103+
104+
vector_query = db.query(
105+
kind="coffee-beans",
106+
find_nearest=FindNearest(
107+
vector_property="embedding_field",
108+
query_vector=Vector([3.0, 1.0, 2.0]),
109+
distance_measure=DistanceMeasure.EUCLIDEAN,
110+
limit=5,
111+
distance_result_property="vector_distance",
112+
)
113+
)
114+
vector_query.projection = ["color"]
115+
116+
for entity in vector_query.fetch():
117+
print(f"{entity.id}, Distance: {entity['vector_distance']}")
118+
# [END datastore_vector_search_distance_result_property_projection]
119+
return vector_query
120+
98121

99122
def vector_search_distance_threshold(db):
100123
# [START datastore_vector_search_distance_threshold]

datastore/cloud-client/vector_search_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,33 @@ def test_vector_search_distance_result_property(db):
9393
assert len(results) == 4
9494
assert results[0].key.name == "Liberica"
9595
assert results[0]["vector_distance"] == 0.0
96+
assert results[0]["embedding_field"] == Vector([3.0, 1.0, 2.0])
9697
assert results[1].key.name == "Robusta"
9798
assert results[1]["vector_distance"] == 1.0
99+
assert results[1]["embedding_field"] == Vector([4.0, 1.0, 2.0])
98100
assert results[2].key.name == "Arabica"
99101
assert results[2]["vector_distance"] == 7.0
102+
assert results[2]["embedding_field"] == Vector([10.0, 1.0, 2.0])
100103
assert results[3].key.name == "Excelsa"
101104
assert results[3]["vector_distance"] == 8.0
105+
assert results[3]["embedding_field"] == Vector([11.0, 1.0, 2.0])
106+
107+
108+
def test_vector_search_distance_result_property_projection(db):
109+
vector_query = vector_search_distance_result_property_projection(db)
110+
results = list(vector_query.fetch())
111+
112+
assert len(results) == 4
113+
assert results[0].key.name == "Liberica"
114+
assert results[0]["vector_distance"] == 0.0
115+
assert results[1].key.name == "Robusta"
116+
assert results[1]["vector_distance"] == 1.0
117+
assert results[2].key.name == "Arabica"
118+
assert results[2]["vector_distance"] == 7.0
119+
assert results[3].key.name == "Excelsa"
120+
assert results[3]["vector_distance"] == 8.0
121+
122+
assert all("embedding_field" not in d for d in results)
102123

103124

104125
def test_vector_search_distance_threshold(db):

0 commit comments

Comments
 (0)