Skip to content

Commit 4387b2d

Browse files
committed
add test of similarity functions
1 parent 5dd0095 commit 4387b2d

File tree

5 files changed

+64
-20
lines changed

5 files changed

+64
-20
lines changed

query-builder/revapi.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2775,12 +2775,12 @@
27752775
},
27762776
{
27772777
"code": "java.method.addedToInterface",
2778-
"new": "method com.datastax.oss.driver.api.querybuilder.select.Select com.datastax.oss.driver.api.querybuilder.select.Select::orderBy(java.lang.String, com.datastax.oss.driver.api.core.data.CqlVector<? extends java.lang.Number>)",
2778+
"new": "method com.datastax.oss.driver.api.querybuilder.select.Select com.datastax.oss.driver.api.querybuilder.select.Select::orderByAnnOf(java.lang.String, com.datastax.oss.driver.api.core.data.CqlVector<? extends java.lang.Number>)",
27792779
"justification": "JAVA-3118: Add support for vector data type in Schema Builder, QueryBuilder"
27802780
},
27812781
{
27822782
"code": "java.method.addedToInterface",
2783-
"new": "method com.datastax.oss.driver.api.querybuilder.select.Select com.datastax.oss.driver.api.querybuilder.select.Select::orderBy(com.datastax.oss.driver.api.core.CqlIdentifier, com.datastax.oss.driver.api.core.data.CqlVector<? extends java.lang.Number>)",
2783+
"new": "method com.datastax.oss.driver.api.querybuilder.select.Select com.datastax.oss.driver.api.querybuilder.select.Select::orderByAnnOf(com.datastax.oss.driver.api.core.CqlIdentifier, com.datastax.oss.driver.api.core.data.CqlVector<? extends java.lang.Number>)",
27842784
"justification": "JAVA-3118: Add support for vector data type in Schema Builder, QueryBuilder"
27852785
}
27862786
]

query-builder/src/main/java/com/datastax/oss/driver/api/querybuilder/select/Select.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,15 +148,15 @@ default Select orderBy(@NonNull String columnName, @NonNull ClusteringOrder orde
148148
}
149149

150150
/**
151-
* Shortcut for {@link #orderBy(CqlIdentifier, CqlVector)}, adding an ORDER BY ... ANN OF ...
151+
* Shortcut for {@link #orderByAnnOf(CqlIdentifier, CqlVector)}, adding an ORDER BY ... ANN OF ...
152152
* clause
153153
*/
154154
@NonNull
155-
Select orderBy(@NonNull String columnName, @NonNull CqlVector<? extends Number> ann);
155+
Select orderByAnnOf(@NonNull String columnName, @NonNull CqlVector<? extends Number> ann);
156156

157157
/** Adds the ORDER BY ... ANN OF ... clause */
158158
@NonNull
159-
Select orderBy(@NonNull CqlIdentifier columnId, @NonNull CqlVector<? extends Number> ann);
159+
Select orderByAnnOf(@NonNull CqlIdentifier columnId, @NonNull CqlVector<? extends Number> ann);
160160
/**
161161
* Adds a LIMIT clause to this query with a literal value.
162162
*

query-builder/src/main/java/com/datastax/oss/driver/internal/querybuilder/select/DefaultSelect.java

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ public SelectFrom json() {
127127
relations,
128128
groupByClauses,
129129
orderings,
130-
null,
130+
ann,
131131
limit,
132132
perPartitionLimit,
133133
allowsFiltering);
@@ -145,7 +145,7 @@ public SelectFrom distinct() {
145145
relations,
146146
groupByClauses,
147147
orderings,
148-
null,
148+
ann,
149149
limit,
150150
perPartitionLimit,
151151
allowsFiltering);
@@ -205,7 +205,7 @@ public Select withSelectors(@NonNull ImmutableList<Selector> newSelectors) {
205205
relations,
206206
groupByClauses,
207207
orderings,
208-
null,
208+
ann,
209209
limit,
210210
perPartitionLimit,
211211
allowsFiltering);
@@ -234,7 +234,7 @@ public Select withRelations(@NonNull ImmutableList<Relation> newRelations) {
234234
newRelations,
235235
groupByClauses,
236236
orderings,
237-
null,
237+
ann,
238238
limit,
239239
perPartitionLimit,
240240
allowsFiltering);
@@ -263,7 +263,7 @@ public Select withGroupByClauses(@NonNull ImmutableList<Selector> newGroupByClau
263263
relations,
264264
newGroupByClauses,
265265
orderings,
266-
null,
266+
ann,
267267
limit,
268268
perPartitionLimit,
269269
allowsFiltering);
@@ -277,13 +277,14 @@ public Select orderBy(@NonNull CqlIdentifier columnId, @NonNull ClusteringOrder
277277

278278
@NonNull
279279
@Override
280-
public Select orderBy(@NonNull String columnName, @NonNull CqlVector<? extends Number> ann) {
280+
public Select orderByAnnOf(@NonNull String columnName, @NonNull CqlVector<? extends Number> ann) {
281281
return withAnn(new Ann(CqlIdentifier.fromCql(columnName), ann));
282282
}
283283

284284
@NonNull
285285
@Override
286-
public Select orderBy(@NonNull CqlIdentifier columnId, @NonNull CqlVector<? extends Number> ann) {
286+
public Select orderByAnnOf(
287+
@NonNull CqlIdentifier columnId, @NonNull CqlVector<? extends Number> ann) {
287288
return withAnn(new Ann(columnId, ann));
288289
}
289290

@@ -304,7 +305,7 @@ public Select withOrderings(@NonNull ImmutableMap<CqlIdentifier, ClusteringOrder
304305
relations,
305306
groupByClauses,
306307
newOrderings,
307-
null,
308+
ann,
308309
limit,
309310
perPartitionLimit,
310311
allowsFiltering);
@@ -340,7 +341,7 @@ public Select limit(int limit) {
340341
relations,
341342
groupByClauses,
342343
orderings,
343-
null,
344+
ann,
344345
limit,
345346
perPartitionLimit,
346347
allowsFiltering);
@@ -358,7 +359,7 @@ public Select limit(@Nullable BindMarker bindMarker) {
358359
relations,
359360
groupByClauses,
360361
orderings,
361-
null,
362+
ann,
362363
bindMarker,
363364
perPartitionLimit,
364365
allowsFiltering);
@@ -378,7 +379,7 @@ public Select perPartitionLimit(int perPartitionLimit) {
378379
relations,
379380
groupByClauses,
380381
orderings,
381-
null,
382+
ann,
382383
limit,
383384
perPartitionLimit,
384385
allowsFiltering);
@@ -396,7 +397,7 @@ public Select perPartitionLimit(@Nullable BindMarker bindMarker) {
396397
relations,
397398
groupByClauses,
398399
orderings,
399-
null,
400+
ann,
400401
limit,
401402
bindMarker,
402403
allowsFiltering);
@@ -414,7 +415,7 @@ public Select allowFiltering() {
414415
relations,
415416
groupByClauses,
416417
orderings,
417-
null,
418+
ann,
418419
limit,
419420
perPartitionLimit,
420421
true);

query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/select/SelectOrderingTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public void should_generate_ann_clause() {
8282
selectFrom("foo")
8383
.all()
8484
.where(Relation.column("k").isEqualTo(literal(1)))
85-
.orderBy("c1", CqlVector.newInstance(0.1, 0.2, 0.3)))
85+
.orderByAnnOf("c1", CqlVector.newInstance(0.1, 0.2, 0.3)))
8686
.hasCql("SELECT * FROM foo WHERE k=1 ORDER BY c1 ANN OF [0.1, 0.2, 0.3]");
8787
}
8888

@@ -92,6 +92,6 @@ public void should_fail_when_provided_ann_with_other_orderings() {
9292
.all()
9393
.where(Relation.column("k").isEqualTo(literal(1)))
9494
.orderBy("c1", ASC)
95-
.orderBy("c2", CqlVector.newInstance(0.1, 0.2, 0.3));
95+
.orderByAnnOf("c2", CqlVector.newInstance(0.1, 0.2, 0.3));
9696
}
9797
}

query-builder/src/test/java/com/datastax/oss/driver/api/querybuilder/select/SelectSelectorTest.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.raw;
2323
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.selectFrom;
2424

25+
import com.datastax.oss.driver.api.core.data.CqlVector;
2526
import com.datastax.oss.driver.api.core.type.DataTypes;
2627
import com.datastax.oss.driver.api.core.type.codec.CodecNotFoundException;
2728
import com.datastax.oss.driver.api.querybuilder.CharsetCodec;
@@ -230,6 +231,48 @@ public void should_generate_raw_selector() {
230231
.hasCql("SELECT bar,baz FROM foo");
231232
}
232233

234+
@Test
235+
public void should_generate_similarity_functions() {
236+
Select similarity_cosine_clause =
237+
selectFrom("cycling", "comments_vs")
238+
.column("comment")
239+
.function(
240+
"similarity_cosine",
241+
Selector.column("comment_vector"),
242+
literal(CqlVector.newInstance(0.2, 0.15, 0.3, 0.2, 0.05)))
243+
.orderByAnnOf("comment_vector", CqlVector.newInstance(0.1, 0.15, 0.3, 0.12, 0.05))
244+
.limit(1);
245+
assertThat(similarity_cosine_clause)
246+
.hasCql(
247+
"SELECT comment,similarity_cosine(comment_vector,[0.2, 0.15, 0.3, 0.2, 0.05]) FROM cycling.comments_vs ORDER BY comment_vector ANN OF [0.1, 0.15, 0.3, 0.12, 0.05] LIMIT 1");
248+
249+
Select similarity_euclidean_clause =
250+
selectFrom("cycling", "comments_vs")
251+
.column("comment")
252+
.function(
253+
"similarity_euclidean",
254+
Selector.column("comment_vector"),
255+
literal(CqlVector.newInstance(0.2, 0.15, 0.3, 0.2, 0.05)))
256+
.orderByAnnOf("comment_vector", CqlVector.newInstance(0.1, 0.15, 0.3, 0.12, 0.05))
257+
.limit(1);
258+
assertThat(similarity_euclidean_clause)
259+
.hasCql(
260+
"SELECT comment,similarity_euclidean(comment_vector,[0.2, 0.15, 0.3, 0.2, 0.05]) FROM cycling.comments_vs ORDER BY comment_vector ANN OF [0.1, 0.15, 0.3, 0.12, 0.05] LIMIT 1");
261+
262+
Select similarity_dot_product_clause =
263+
selectFrom("cycling", "comments_vs")
264+
.column("comment")
265+
.function(
266+
"similarity_dot_product",
267+
Selector.column("comment_vector"),
268+
literal(CqlVector.newInstance(0.2, 0.15, 0.3, 0.2, 0.05)))
269+
.orderByAnnOf("comment_vector", CqlVector.newInstance(0.1, 0.15, 0.3, 0.12, 0.05))
270+
.limit(1);
271+
assertThat(similarity_dot_product_clause)
272+
.hasCql(
273+
"SELECT comment,similarity_dot_product(comment_vector,[0.2, 0.15, 0.3, 0.2, 0.05]) FROM cycling.comments_vs ORDER BY comment_vector ANN OF [0.1, 0.15, 0.3, 0.12, 0.05] LIMIT 1");
274+
}
275+
233276
@Test
234277
public void should_alias_selectors() {
235278
assertThat(selectFrom("foo").column("bar").as("baz")).hasCql("SELECT bar AS baz FROM foo");

0 commit comments

Comments
 (0)