Skip to content

Commit 8163c95

Browse files
committed
Add unit tests and system tests
1 parent 9c19afa commit 8163c95

File tree

8 files changed

+294
-25
lines changed

8 files changed

+294
-25
lines changed
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.cloud.datastore;
18+
19+
import com.google.common.base.MoreObjects;
20+
import com.google.common.base.MoreObjects.ToStringHelper;
21+
import com.google.protobuf.ByteString;
22+
import com.google.protobuf.DoubleValue;
23+
import com.google.protobuf.Int32Value;
24+
25+
import java.io.Serializable;
26+
import java.util.Objects;
27+
28+
import javax.annotation.Nullable;
29+
30+
31+
/**
32+
* A query that finds the entities whose vector fields are closest to a certain query vector.
33+
* Create an instance of `FindNearest` with {@link Query#findNearest}.
34+
*/
35+
public final class FindNearest implements Serializable {
36+
37+
private final String vectorProperty;
38+
private final VectorValue queryVector;
39+
private final DistanceMeasure measure;
40+
private final int limit;
41+
/*
42+
Optional. Optional name of the field to output the result of the vector
43+
* distance calculation.
44+
*/
45+
private final @Nullable String distanceResultField;
46+
private final @Nullable Double distanceThreshold;
47+
48+
private static final long serialVersionUID = 4688656124180403551L;
49+
50+
/** Creates a VectorQuery */
51+
public FindNearest(String vectorProperty, VectorValue queryVector, DistanceMeasure measure, int limit, @Nullable String distanceResultField,@Nullable Double distanceThreshold) {
52+
this.vectorProperty = vectorProperty;
53+
this.queryVector = queryVector;
54+
this.measure = measure;
55+
this.limit = limit;
56+
this.distanceResultField = distanceResultField;
57+
this.distanceThreshold = distanceThreshold;
58+
}
59+
60+
public FindNearest(String vectorProperty, VectorValue queryVector, DistanceMeasure measure, int limit) {
61+
this(vectorProperty, queryVector, measure, limit, null, null);
62+
}
63+
64+
public FindNearest(String vectorProperty, VectorValue queryVector, DistanceMeasure measure, int limit, @Nullable String distanceResultField) {
65+
this(vectorProperty, queryVector, measure, limit, distanceResultField, null);
66+
}
67+
68+
public FindNearest(String vectorProperty, VectorValue queryVector, DistanceMeasure measure, int limit, @Nullable Double distanceThreshold) {
69+
this(vectorProperty, queryVector, measure, limit, null, distanceThreshold);
70+
}
71+
72+
@Override
73+
public int hashCode() {
74+
return Objects.hash(vectorProperty, queryVector, measure, limit, distanceResultField, distanceThreshold);
75+
}
76+
77+
/**
78+
* Returns true if this VectorQuery is equal to the provided object.
79+
*
80+
* @param obj The object to compare against.
81+
* @return Whether this VectorQuery is equal to the provided object.
82+
*/
83+
@Override
84+
public boolean equals(Object obj) {
85+
if (this == obj) {
86+
return true;
87+
}
88+
if (obj == null || !(obj instanceof FindNearest)) {
89+
return false;
90+
}
91+
FindNearest otherQuery = (FindNearest) obj;
92+
return Objects.equals(vectorProperty, otherQuery.vectorProperty)
93+
&& Objects.equals(queryVector, otherQuery.queryVector)
94+
&& Objects.equals(distanceResultField, otherQuery.distanceResultField)
95+
&& Objects.equals(distanceThreshold, otherQuery.distanceThreshold)
96+
&& limit == otherQuery.limit
97+
&& measure == otherQuery.measure;
98+
}
99+
100+
@Override
101+
public String toString() {
102+
ToStringHelper toStringHelper = MoreObjects.toStringHelper(this);
103+
toStringHelper.add("vectorProperty", vectorProperty);
104+
toStringHelper.add("queryVector", queryVector);
105+
toStringHelper.add("measure", measure);
106+
toStringHelper.add("limit", limit);
107+
toStringHelper.add("distanceResultField", distanceResultField);
108+
toStringHelper.add("distanceThreshold", distanceThreshold);
109+
return toStringHelper.toString();
110+
}
111+
112+
static FindNearest fromPb(com.google.datastore.v1.FindNearest findNearestPb) {
113+
String vectorProperty = findNearestPb.getVectorProperty().getName();
114+
VectorValue queryVector = VectorValue.MARSHALLER.fromProto(findNearestPb.getQueryVector()).build();
115+
DistanceMeasure distanceMeasure = DistanceMeasure.valueOf(findNearestPb.getDistanceMeasure().toString());
116+
int limit = findNearestPb.getLimit().getValue();
117+
String distanceResultField = findNearestPb.getDistanceResultProperty() == null || findNearestPb.getDistanceResultProperty().isEmpty() ? null:findNearestPb.getDistanceResultProperty();
118+
Double distanceThreshold = findNearestPb.getDistanceThreshold() == null || findNearestPb.getDistanceThreshold() == DoubleValue.getDefaultInstance() ? null : findNearestPb.getDistanceThreshold().getValue();
119+
return new FindNearest(vectorProperty,queryVector, distanceMeasure, limit, distanceResultField, distanceThreshold);
120+
}
121+
122+
com.google.datastore.v1.FindNearest toPb() {
123+
com.google.datastore.v1.FindNearest.Builder findNearestPb = com.google.datastore.v1.FindNearest.newBuilder();
124+
findNearestPb.getVectorPropertyBuilder().setName(vectorProperty);
125+
findNearestPb.setQueryVector(queryVector.toPb());
126+
findNearestPb.setDistanceMeasure(toProto(measure));
127+
findNearestPb.setLimit(Int32Value.of(limit));
128+
if (distanceResultField != null)
129+
{
130+
findNearestPb.setDistanceResultProperty(distanceResultField);
131+
}
132+
if (distanceThreshold != null) {
133+
findNearestPb.setDistanceThreshold(DoubleValue.of(distanceThreshold));
134+
}
135+
return findNearestPb.build();
136+
}
137+
138+
protected static com.google.datastore.v1.FindNearest.DistanceMeasure toProto(
139+
DistanceMeasure distanceMeasure) {
140+
switch (distanceMeasure) {
141+
case COSINE:
142+
return com.google.datastore.v1.FindNearest.DistanceMeasure.COSINE;
143+
case EUCLIDEAN:
144+
return com.google.datastore.v1.FindNearest.DistanceMeasure.EUCLIDEAN;
145+
case DOT_PRODUCT:
146+
return com.google.datastore.v1.FindNearest.DistanceMeasure.DOT_PRODUCT;
147+
default:
148+
return com.google.datastore.v1.FindNearest.DistanceMeasure.UNRECOGNIZED;
149+
}
150+
}
151+
152+
/**
153+
* The distance measure to use when comparing vectors in a {@link FindNearest query}.
154+
*
155+
* @see com.google.cloud.firestore.Query#findNearest
156+
*/
157+
public enum DistanceMeasure {
158+
/**
159+
* COSINE distance compares vectors based on the angle between them, which allows you to measure
160+
* similarity that isn't based on the vectors' magnitude. We recommend using DOT_PRODUCT with
161+
* unit normalized vectors instead of COSINE distance, which is mathematically equivalent with
162+
* better performance.
163+
*/
164+
COSINE,
165+
/** Measures the EUCLIDEAN distance between the vectors. */
166+
EUCLIDEAN,
167+
/** Similar to cosine but is affected by the magnitude of the vectors. */
168+
DOT_PRODUCT
169+
}
170+
}

google-cloud-datastore/src/main/java/com/google/cloud/datastore/StructuredQuery.java

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ public abstract class StructuredQuery<V> extends Query<V> implements RecordQuery
101101
private final Cursor endCursor;
102102
private final int offset;
103103
private final Integer limit;
104+
private final FindNearest findNearest;
104105

105106
private final ResultType<V> resultType;
106107

@@ -731,6 +732,9 @@ public interface Builder<V> {
731732
/** Adds settings to the existing order by clause. */
732733
Builder<V> addOrderBy(OrderBy orderBy, OrderBy... others);
733734

735+
/** Sets the find_nearest for the query. */
736+
Builder<V> setFindNearest(FindNearest findNearest);
737+
734738
StructuredQuery<V> build();
735739
}
736740

@@ -753,6 +757,7 @@ abstract static class BuilderImpl<V, B extends BuilderImpl<V, B>> implements Bui
753757
private Cursor endCursor;
754758
private int offset;
755759
private Integer limit;
760+
private FindNearest findNearest;
756761

757762
BuilderImpl(ResultType<V> resultType) {
758763
this.resultType = resultType;
@@ -770,6 +775,7 @@ abstract static class BuilderImpl<V, B extends BuilderImpl<V, B>> implements Bui
770775
endCursor = query.endCursor;
771776
offset = query.offset;
772777
limit = query.limit;
778+
findNearest = query.findNearest;
773779
}
774780

775781
@SuppressWarnings("unchecked")
@@ -841,6 +847,13 @@ public B addOrderBy(OrderBy orderBy, OrderBy... others) {
841847
return self();
842848
}
843849

850+
@Override
851+
public B setFindNearest(FindNearest findNearest) {
852+
Preconditions.checkArgument(findNearest != null, "vector query must not be null");
853+
this.findNearest = findNearest;
854+
return self();
855+
}
856+
844857
B clearProjection() {
845858
projection.clear();
846859
return self();
@@ -904,6 +917,9 @@ B mergeFrom(com.google.datastore.v1.Query queryPb) {
904917
for (com.google.datastore.v1.PropertyReference distinctOnPb : queryPb.getDistinctOnList()) {
905918
addDistinctOn(distinctOnPb.getName());
906919
}
920+
if (queryPb.getFindNearest() != null) {
921+
setFindNearest(FindNearest.fromPb(queryPb.getFindNearest()));
922+
}
907923
return self();
908924
}
909925
}
@@ -920,6 +936,7 @@ B mergeFrom(com.google.datastore.v1.Query queryPb) {
920936
endCursor = builder.endCursor;
921937
offset = builder.offset;
922938
limit = builder.limit;
939+
findNearest = builder.findNearest;
923940
}
924941

925942
@Override
@@ -935,6 +952,7 @@ public String toString() {
935952
.add("orderBy", orderBy)
936953
.add("projection", projection)
937954
.add("distinctOn", distinctOn)
955+
.add("findNearest", findNearest)
938956
.toString();
939957
}
940958

@@ -950,7 +968,8 @@ public int hashCode() {
950968
filter,
951969
orderBy,
952970
projection,
953-
distinctOn);
971+
distinctOn,
972+
findNearest);
954973
}
955974

956975
@Override
@@ -971,7 +990,8 @@ public boolean equals(Object obj) {
971990
&& Objects.equals(filter, other.filter)
972991
&& Objects.equals(orderBy, other.orderBy)
973992
&& Objects.equals(projection, other.projection)
974-
&& Objects.equals(distinctOn, other.distinctOn);
993+
&& Objects.equals(distinctOn, other.distinctOn)
994+
&& Objects.equals(findNearest, other.findNearest);
975995
}
976996

977997
/** Returns the kind for this query. */
@@ -1023,6 +1043,11 @@ public Integer getLimit() {
10231043
return limit;
10241044
}
10251045

1046+
/** Returns the vector query for this query. */
1047+
public FindNearest getFindNearest() {
1048+
return findNearest;
1049+
}
1050+
10261051
public abstract Builder<V> toBuilder();
10271052

10281053
@InternalApi

google-cloud-datastore/src/main/java/com/google/cloud/datastore/StructuredQueryProtoPreparer.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ public Query prepare(StructuredQuery<?> query) {
6060
.build();
6161
queryPb.addProjection(expressionPb);
6262
}
63+
if (query.getFindNearest() != null) {
64+
queryPb.setFindNearest(query.getFindNearest().toPb());
65+
}
6366

6467
return queryPb.build();
6568
}

google-cloud-datastore/src/test/java/com/google/cloud/datastore/ProtoTestData.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@
2727
import com.google.datastore.v1.PropertyOrder;
2828
import com.google.datastore.v1.PropertyReference;
2929
import com.google.datastore.v1.Value;
30+
import com.google.protobuf.DoubleValue;
31+
import com.google.protobuf.Int32Value;
32+
import com.google.cloud.datastore.FindNearest.DistanceMeasure;
33+
34+
import javax.annotation.Nullable;
3035

3136
public class ProtoTestData {
3237

@@ -83,4 +88,27 @@ public static PropertyOrder propertyOrder(String value) {
8388
public static Projection projection(String value) {
8489
return Projection.newBuilder().setProperty(propertyReference(value)).build();
8590
}
91+
92+
public static com.google.datastore.v1.FindNearest FindNearest(String vectorProperty, VectorValue queryVector, DistanceMeasure measure, int limit) {
93+
return FindNearest(vectorProperty, queryVector, measure, limit, null, null);
94+
}
95+
96+
public static com.google.datastore.v1.FindNearest FindNearest(String vectorProperty, VectorValue queryVector, DistanceMeasure measure, int limit, @Nullable String distanceResultField, @Nullable Double distanceThreshold){
97+
com.google.datastore.v1.FindNearest.Builder builder = com.google.datastore.v1.FindNearest.newBuilder()
98+
.setVectorProperty(propertyReference(vectorProperty))
99+
.setQueryVector(queryVector.toPb())
100+
.setDistanceMeasure(FindNearest.toProto(measure))
101+
.setLimit(Int32Value.of(limit));
102+
103+
if (distanceResultField != null)
104+
{
105+
builder.setDistanceResultProperty(distanceResultField);
106+
}
107+
if (distanceThreshold != null)
108+
{
109+
builder.setDistanceThreshold(DoubleValue.of(distanceThreshold));
110+
}
111+
112+
return builder.build();
113+
}
86114
}

google-cloud-datastore/src/test/java/com/google/cloud/datastore/StructuredQueryProtoPreparerTest.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2022 Google LLC
2+
* Copyright 2024 Google LLC
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@
2020
import static com.google.cloud.datastore.ProtoTestData.propertyFilter;
2121
import static com.google.cloud.datastore.ProtoTestData.propertyOrder;
2222
import static com.google.cloud.datastore.ProtoTestData.propertyReference;
23+
import static com.google.cloud.datastore.ProtoTestData.FindNearest;
2324
import static com.google.cloud.datastore.Query.newEntityQueryBuilder;
2425
import static com.google.common.truth.Truth.assertThat;
2526
import static com.google.datastore.v1.PropertyFilter.Operator.EQUAL;
@@ -86,6 +87,15 @@ public void testFilter() {
8687
assertThat(queryProto.getFilter()).isEqualTo(propertyFilter("done", EQUAL, booleanValue(true)));
8788
}
8889

90+
@Test
91+
public void testFindNearest() {
92+
VectorValue VECTOR_VALUE = VectorValue.newBuilder(1.78, 2.56, 3.88).build();
93+
FindNearest FIND_NEAREST = new FindNearest("vector_property", VECTOR_VALUE, FindNearest.DistanceMeasure.COSINE, 1);
94+
Query queryProto = protoPreparer.prepare(newEntityQueryBuilder().setFindNearest(FIND_NEAREST).build());
95+
assertThat(queryProto.getFindNearest())
96+
.isEqualTo(FindNearest("vector_property", VECTOR_VALUE, FindNearest.DistanceMeasure.COSINE, 1));
97+
}
98+
8999
@Test
90100
public void testOrderBy() {
91101
Query queryProto =

0 commit comments

Comments
 (0)