Skip to content

Commit e2fb42d

Browse files
authored
JAVA-3061 Re-introduce an improved CqlVector, add support for accessing vectors directly as float arrays (#1666)
1 parent fc79bb7 commit e2fb42d

File tree

18 files changed

+901
-58
lines changed

18 files changed

+901
-58
lines changed

core/revapi.json

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6887,7 +6887,70 @@
68876887
"code": "java.method.removed",
68886888
"old": "method <T> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType<T>)",
68896889
"justification": "Refactoring in JAVA-3061"
6890-
}
6890+
},
6891+
{
6892+
"code": "java.class.removed",
6893+
"old": "class com.datastax.oss.driver.api.core.data.CqlVector.Builder<T>",
6894+
"justification": "Refactorings in PR 1666"
6895+
},
6896+
{
6897+
"code": "java.method.removed",
6898+
"old": "method com.datastax.oss.driver.api.core.data.CqlVector.Builder com.datastax.oss.driver.api.core.data.CqlVector<T>::builder()",
6899+
"justification": "Refactorings in PR 1666"
6900+
},
6901+
{
6902+
"code": "java.method.removed",
6903+
"old": "method java.lang.Iterable<T> com.datastax.oss.driver.api.core.data.CqlVector<T>::getValues()",
6904+
"justification": "Refactorings in PR 1666"
6905+
},
6906+
{
6907+
"code": "java.generics.formalTypeParameterChanged",
6908+
"old": "class com.datastax.oss.driver.api.core.data.CqlVector<T>",
6909+
"new": "class com.datastax.oss.driver.api.core.data.CqlVector<T extends java.lang.Number>",
6910+
"justification": "Refactorings in PR 1666"
6911+
},
6912+
{
6913+
"code": "java.method.parameterTypeChanged",
6914+
"old": "parameter <SubtypeT> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(===com.datastax.oss.driver.api.core.type.CqlVectorType===, com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>)",
6915+
"new": "parameter <SubtypeT extends java.lang.Number> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(===com.datastax.oss.driver.api.core.type.VectorType===, com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>)",
6916+
"justification": "Refactorings in PR 1666"
6917+
},
6918+
{
6919+
"code": "java.method.parameterTypeParameterChanged",
6920+
"old": "parameter <SubtypeT> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.CqlVectorType, ===com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>===)",
6921+
"new": "parameter <SubtypeT extends java.lang.Number> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.VectorType, ===com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>===)",
6922+
"justification": "Refactorings in PR 1666"
6923+
},
6924+
{
6925+
"code": "java.method.returnTypeTypeParametersChanged",
6926+
"old": "method <SubtypeT> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.CqlVectorType, com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>)",
6927+
"new": "method <SubtypeT extends java.lang.Number> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.VectorType, com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>)",
6928+
"justification": "Refactorings in PR 1666"
6929+
},
6930+
{
6931+
"code": "java.generics.formalTypeParameterChanged",
6932+
"old": "method <SubtypeT> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.CqlVectorType, com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>)",
6933+
"new": "method <SubtypeT extends java.lang.Number> com.datastax.oss.driver.api.core.type.codec.TypeCodec<com.datastax.oss.driver.api.core.data.CqlVector<SubtypeT>> com.datastax.oss.driver.api.core.type.codec.TypeCodecs::vectorOf(com.datastax.oss.driver.api.core.type.VectorType, com.datastax.oss.driver.api.core.type.codec.TypeCodec<SubtypeT>)",
6934+
"justification": "Refactorings in PR 1666"
6935+
},
6936+
{
6937+
"code": "java.method.parameterTypeParameterChanged",
6938+
"old": "parameter <T> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(===com.datastax.oss.driver.api.core.type.reflect.GenericType<T>===)",
6939+
"new": "parameter <T extends java.lang.Number> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(===com.datastax.oss.driver.api.core.type.reflect.GenericType<T>===)",
6940+
"justification": "Refactorings in PR 1666"
6941+
},
6942+
{
6943+
"code": "java.method.returnTypeTypeParametersChanged",
6944+
"old": "method <T> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType<T>)",
6945+
"new": "method <T extends java.lang.Number> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType<T>)",
6946+
"justification": "Refactorings in PR 1666"
6947+
},
6948+
{
6949+
"code": "java.generics.formalTypeParameterChanged",
6950+
"old": "method <T> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType<T>)",
6951+
"new": "method <T extends java.lang.Number> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(com.datastax.oss.driver.api.core.type.reflect.GenericType<T>)",
6952+
"justification": "Refactorings in PR 1666"
6953+
}
68916954
]
68926955
}
68936956
}
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
/*
2+
* Copyright DataStax, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package com.datastax.oss.driver.api.core.data;
17+
18+
import com.datastax.oss.driver.api.core.type.codec.TypeCodec;
19+
import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;
20+
import com.datastax.oss.driver.shaded.guava.common.base.Predicates;
21+
import com.datastax.oss.driver.shaded.guava.common.base.Splitter;
22+
import com.datastax.oss.driver.shaded.guava.common.collect.Iterables;
23+
import com.datastax.oss.driver.shaded.guava.common.collect.Streams;
24+
import edu.umd.cs.findbugs.annotations.NonNull;
25+
import java.util.ArrayList;
26+
import java.util.Arrays;
27+
import java.util.Iterator;
28+
import java.util.List;
29+
import java.util.Objects;
30+
import java.util.stream.Collectors;
31+
import java.util.stream.Stream;
32+
33+
/**
34+
* Representation of a vector as defined in CQL.
35+
*
36+
* <p>A CQL vector is a fixed-length array of non-null numeric values. These properties don't map
37+
* cleanly to an existing class in the standard JDK Collections hierarchy so we provide this value
38+
* object instead. Like other value object collections returned by the driver instances of this
39+
* class are not immutable; think of these value objects as a representation of a vector stored in
40+
* the database as an initial step in some additional computation.
41+
*
42+
* <p>While we don't implement any Collection APIs we do implement Iterable. We also attempt to play
43+
* nice with the Streams API in order to better facilitate integration with data pipelines. Finally,
44+
* where possible we've tried to make the API of this class similar to the equivalent methods on
45+
* {@link List}.
46+
*/
47+
public class CqlVector<T extends Number> implements Iterable<T> {
48+
49+
/**
50+
* Create a new CqlVector containing the specified values.
51+
*
52+
* @param vals the collection of values to wrap.
53+
* @return a CqlVector wrapping those values
54+
*/
55+
public static <V extends Number> CqlVector<V> newInstance(V... vals) {
56+
57+
// Note that Array.asList() guarantees the return of an array which implements RandomAccess
58+
return new CqlVector(Arrays.asList(vals));
59+
}
60+
61+
/**
62+
* Create a new CqlVector that "wraps" an existing ArrayList. Modifications to the passed
63+
* ArrayList will also be reflected in the returned CqlVector.
64+
*
65+
* @param list the collection of values to wrap.
66+
* @return a CqlVector wrapping those values
67+
*/
68+
public static <V extends Number> CqlVector<V> newInstance(List<V> list) {
69+
Preconditions.checkArgument(list != null, "Input list should not be null");
70+
return new CqlVector(list);
71+
}
72+
73+
/**
74+
* Create a new CqlVector instance from the specified string representation. Note that this method
75+
* is intended to mirror {@link #toString()}; passing this method the output from a <code>toString
76+
* </code> call on some CqlVector should return a CqlVector that is equal to the origin instance.
77+
*
78+
* @param str a String representation of a CqlVector
79+
* @param subtypeCodec
80+
* @return a new CqlVector built from the String representation
81+
*/
82+
public static <V extends Number> CqlVector<V> from(
83+
@NonNull String str, @NonNull TypeCodec<V> subtypeCodec) {
84+
Preconditions.checkArgument(str != null, "Cannot create CqlVector from null string");
85+
Preconditions.checkArgument(!str.isEmpty(), "Cannot create CqlVector from empty string");
86+
ArrayList<V> vals =
87+
Streams.stream(Splitter.on(", ").split(str.substring(1, str.length() - 1)))
88+
.map(subtypeCodec::parse)
89+
.collect(Collectors.toCollection(ArrayList::new));
90+
return new CqlVector(vals);
91+
}
92+
93+
private final List<T> list;
94+
95+
private CqlVector(@NonNull List<T> list) {
96+
97+
Preconditions.checkArgument(
98+
Iterables.all(list, Predicates.notNull()), "CqlVectors cannot contain null values");
99+
this.list = list;
100+
}
101+
102+
/**
103+
* Retrieve the value at the specified index. Modelled after {@link List#get(int)}
104+
*
105+
* @param idx the index to retrieve
106+
* @return the value at the specified index
107+
*/
108+
public T get(int idx) {
109+
return list.get(idx);
110+
}
111+
112+
/**
113+
* Update the value at the specified index. Modelled after {@link List#set(int, Object)}
114+
*
115+
* @param idx the index to set
116+
* @param val the new value for the specified index
117+
* @return the old value for the specified index
118+
*/
119+
public T set(int idx, T val) {
120+
return list.set(idx, val);
121+
}
122+
123+
/**
124+
* Return the size of this vector. Modelled after {@link List#size()}
125+
*
126+
* @return the vector size
127+
*/
128+
public int size() {
129+
return this.list.size();
130+
}
131+
132+
/**
133+
* Return a CqlVector consisting of the contents of a portion of this vector. Modelled after
134+
* {@link List#subList(int, int)}
135+
*
136+
* @param from the index to start from (inclusive)
137+
* @param to the index to end on (exclusive)
138+
* @return a new CqlVector wrapping the sublist
139+
*/
140+
public CqlVector<T> subVector(int from, int to) {
141+
return new CqlVector<T>(this.list.subList(from, to));
142+
}
143+
144+
/**
145+
* Return a boolean indicating whether the vector is empty. Modelled after {@link List#isEmpty()}
146+
*
147+
* @return true if the list is empty, false otherwise
148+
*/
149+
public boolean isEmpty() {
150+
return this.list.isEmpty();
151+
}
152+
153+
/**
154+
* Create an {@link Iterator} for this vector
155+
*
156+
* @return the generated iterator
157+
*/
158+
@Override
159+
public Iterator<T> iterator() {
160+
return this.list.iterator();
161+
}
162+
163+
/**
164+
* Create a {@link Stream} of the values in this vector
165+
*
166+
* @return the Stream instance
167+
*/
168+
public Stream<T> stream() {
169+
return this.list.stream();
170+
}
171+
172+
@Override
173+
public boolean equals(Object o) {
174+
if (o == this) {
175+
return true;
176+
} else if (o instanceof CqlVector) {
177+
CqlVector that = (CqlVector) o;
178+
return this.list.equals(that.list);
179+
} else {
180+
return false;
181+
}
182+
}
183+
184+
@Override
185+
public int hashCode() {
186+
return Objects.hash(list);
187+
}
188+
189+
@Override
190+
public String toString() {
191+
return Iterables.toString(this.list);
192+
}
193+
}

core/src/main/java/com/datastax/oss/driver/api/core/data/GettableById.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ default CqlDuration getCqlDuration(@NonNull CqlIdentifier id) {
529529
* @throws IllegalArgumentException if the id is invalid.
530530
*/
531531
@Nullable
532-
default <ElementT> List<ElementT> getVector(
532+
default <ElementT extends Number> CqlVector<ElementT> getVector(
533533
@NonNull CqlIdentifier id, @NonNull Class<ElementT> elementsClass) {
534534
return getVector(firstIndexOf(id), elementsClass);
535535
}

core/src/main/java/com/datastax/oss/driver/api/core/data/GettableByIndex.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -444,8 +444,9 @@ default CqlDuration getCqlDuration(int i) {
444444
* @throws IndexOutOfBoundsException if the index is invalid.
445445
*/
446446
@Nullable
447-
default <ElementT> List<ElementT> getVector(int i, @NonNull Class<ElementT> elementsClass) {
448-
return get(i, GenericType.listOf(elementsClass));
447+
default <ElementT extends Number> CqlVector<ElementT> getVector(
448+
int i, @NonNull Class<ElementT> elementsClass) {
449+
return get(i, GenericType.vectorOf(elementsClass));
449450
}
450451

451452
/**

core/src/main/java/com/datastax/oss/driver/api/core/data/GettableByName.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -525,9 +525,9 @@ default CqlDuration getCqlDuration(@NonNull String name) {
525525
* @throws IllegalArgumentException if the name is invalid.
526526
*/
527527
@Nullable
528-
default <ElementT> List<ElementT> getVector(
528+
default <ElementT extends Number> CqlVector<ElementT> getVector(
529529
@NonNull String name, @NonNull Class<ElementT> elementsClass) {
530-
return getList(firstIndexOf(name), elementsClass);
530+
return getVector(firstIndexOf(name), elementsClass);
531531
}
532532

533533
/**

core/src/main/java/com/datastax/oss/driver/api/core/data/SettableById.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -571,9 +571,9 @@ default SelfT setCqlDuration(@NonNull CqlIdentifier id, @Nullable CqlDuration v)
571571
*/
572572
@NonNull
573573
@CheckReturnValue
574-
default <ElementT> SelfT setVector(
574+
default <ElementT extends Number> SelfT setVector(
575575
@NonNull CqlIdentifier id,
576-
@Nullable List<ElementT> v,
576+
@Nullable CqlVector<ElementT> v,
577577
@NonNull Class<ElementT> elementsClass) {
578578
SelfT result = null;
579579
for (Integer i : allIndicesOf(id)) {

core/src/main/java/com/datastax/oss/driver/api/core/data/SettableByIndex.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,9 +423,9 @@ default SelfT setCqlDuration(int i, @Nullable CqlDuration v) {
423423
*/
424424
@NonNull
425425
@CheckReturnValue
426-
default <ElementT> SelfT setVector(
427-
int i, @Nullable List<ElementT> v, @NonNull Class<ElementT> elementsClass) {
428-
return set(i, v, GenericType.listOf(elementsClass));
426+
default <ElementT extends Number> SelfT setVector(
427+
int i, @Nullable CqlVector<ElementT> v, @NonNull Class<ElementT> elementsClass) {
428+
return set(i, v, GenericType.vectorOf(elementsClass));
429429
}
430430

431431
/**

core/src/main/java/com/datastax/oss/driver/api/core/data/SettableByName.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -570,8 +570,10 @@ default SelfT setCqlDuration(@NonNull String name, @Nullable CqlDuration v) {
570570
*/
571571
@NonNull
572572
@CheckReturnValue
573-
default <ElementT> SelfT setVector(
574-
@NonNull String name, @Nullable List<ElementT> v, @NonNull Class<ElementT> elementsClass) {
573+
default <ElementT extends Number> SelfT setVector(
574+
@NonNull String name,
575+
@Nullable CqlVector<ElementT> v,
576+
@NonNull Class<ElementT> elementsClass) {
575577
SelfT result = null;
576578
for (Integer i : allIndicesOf(name)) {
577579
result = (result == null ? this : result).setVector(i, v, elementsClass);

core/src/main/java/com/datastax/oss/driver/api/core/type/codec/ExtraTypeCodecs.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616
package com.datastax.oss.driver.api.core.type.codec;
1717

1818
import com.datastax.oss.driver.api.core.session.SessionBuilder;
19+
import com.datastax.oss.driver.api.core.type.DataTypes;
1920
import com.datastax.oss.driver.api.core.type.codec.registry.MutableCodecRegistry;
2021
import com.datastax.oss.driver.api.core.type.reflect.GenericType;
22+
import com.datastax.oss.driver.internal.core.type.DefaultVectorType;
2123
import com.datastax.oss.driver.internal.core.type.codec.SimpleBlobCodec;
2224
import com.datastax.oss.driver.internal.core.type.codec.TimestampCodec;
2325
import com.datastax.oss.driver.internal.core.type.codec.extras.OptionalCodec;
@@ -36,6 +38,7 @@
3638
import com.datastax.oss.driver.internal.core.type.codec.extras.time.PersistentZonedTimestampCodec;
3739
import com.datastax.oss.driver.internal.core.type.codec.extras.time.TimestampMillisCodec;
3840
import com.datastax.oss.driver.internal.core.type.codec.extras.time.ZonedTimestampCodec;
41+
import com.datastax.oss.driver.internal.core.type.codec.extras.vector.FloatVectorToArrayCodec;
3942
import com.fasterxml.jackson.databind.ObjectMapper;
4043
import edu.umd.cs.findbugs.annotations.NonNull;
4144
import java.nio.ByteBuffer;
@@ -479,4 +482,9 @@ public static <T> TypeCodec<T> json(
479482
@NonNull Class<T> javaType, @NonNull ObjectMapper objectMapper) {
480483
return new JsonCodec<>(javaType, objectMapper);
481484
}
485+
486+
/** Builds a new codec that maps CQL float vectors of the specified size to an array of floats. */
487+
public static TypeCodec<float[]> floatVectorToArray(int dimensions) {
488+
return new FloatVectorToArrayCodec(new DefaultVectorType(DataTypes.FLOAT, dimensions));
489+
}
482490
}

core/src/main/java/com/datastax/oss/driver/api/core/type/codec/TypeCodecs.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package com.datastax.oss.driver.api.core.type.codec;
1717

1818
import com.datastax.oss.driver.api.core.data.CqlDuration;
19+
import com.datastax.oss.driver.api.core.data.CqlVector;
1920
import com.datastax.oss.driver.api.core.data.TupleValue;
2021
import com.datastax.oss.driver.api.core.data.UdtValue;
2122
import com.datastax.oss.driver.api.core.type.CustomType;
@@ -207,12 +208,17 @@ public static TypeCodec<TupleValue> tupleOf(@NonNull TupleType cqlType) {
207208
return new TupleCodec(cqlType);
208209
}
209210

210-
public static <SubtypeT> TypeCodec<List<SubtypeT>> vectorOf(
211+
public static <SubtypeT extends Number> TypeCodec<CqlVector<SubtypeT>> vectorOf(
211212
@NonNull VectorType type, @NonNull TypeCodec<SubtypeT> subtypeCodec) {
212213
return new VectorCodec(
213214
DataTypes.vectorOf(subtypeCodec.getCqlType(), type.getDimensions()), subtypeCodec);
214215
}
215216

217+
public static <SubtypeT extends Number> TypeCodec<CqlVector<SubtypeT>> vectorOf(
218+
int dimensions, @NonNull TypeCodec<SubtypeT> subtypeCodec) {
219+
return new VectorCodec(DataTypes.vectorOf(subtypeCodec.getCqlType(), dimensions), subtypeCodec);
220+
}
221+
216222
/**
217223
* Builds a new codec that maps a CQL user defined type to the driver's {@link UdtValue}, for the
218224
* given type definition.

0 commit comments

Comments
 (0)