Skip to content

Commit c2980a0

Browse files
committed
regarding to PR reviews
1 parent 76b2be9 commit c2980a0

File tree

4 files changed

+58
-79
lines changed

4 files changed

+58
-79
lines changed

core/src/main/java/com/datastax/oss/driver/api/core/data/CqlVector.java

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import com.datastax.oss.driver.shaded.guava.common.base.Preconditions;
2323
import com.datastax.oss.driver.shaded.guava.common.base.Predicates;
2424
import com.datastax.oss.driver.shaded.guava.common.collect.Iterables;
25+
import com.google.common.collect.Iterables;
2526
import edu.umd.cs.findbugs.annotations.NonNull;
2627
import java.io.IOException;
2728
import java.io.InvalidObjectException;
@@ -34,7 +35,6 @@
3435
import java.util.Iterator;
3536
import java.util.List;
3637
import java.util.Objects;
37-
import java.util.stream.Collectors;
3838
import java.util.stream.Stream;
3939

4040
/**
@@ -235,10 +235,7 @@ public int hashCode() {
235235
*/
236236
@Override
237237
public String toString() {
238-
if (this.list.isEmpty()) return "[]";
239-
return this.list.stream()
240-
.map(ele -> ele.toString())
241-
.collect(Collectors.joining(", ", "[", "]"));
238+
return Iterables.toString(this.list);
242239
}
243240

244241
/**

core/src/main/java/com/datastax/oss/driver/api/core/type/DataTypes.java

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -65,29 +65,8 @@ public static DataType custom(@NonNull String className) {
6565
if (className.equals("org.apache.cassandra.db.marshal.DurationType")) return DURATION;
6666

6767
/* Vector support is currently implemented as a custom type but is also parameterized */
68-
if (className.startsWith(DefaultVectorType.VECTOR_CLASS_NAME)) {
69-
String paramsString =
70-
className.substring(
71-
DefaultVectorType.VECTOR_CLASS_NAME.length() + 1, className.length() - 1);
72-
int lastCommaIndex = paramsString.lastIndexOf(',');
73-
if (lastCommaIndex == -1) {
74-
throw new IllegalArgumentException(
75-
String.format(
76-
"Invalid vector type %s, expected format is %s<subtype, dimensions>",
77-
className, DefaultVectorType.VECTOR_CLASS_NAME));
78-
}
79-
String subTypeString = paramsString.substring(0, lastCommaIndex).trim();
80-
String dimensionsString = paramsString.substring(lastCommaIndex + 1).trim();
81-
82-
DataType subType = classNameParser.parse(subTypeString, AttachmentPoint.NONE);
83-
int dimensions = Integer.parseInt(dimensionsString);
84-
if (dimensions <= 0) {
85-
throw new IllegalArgumentException(
86-
String.format(
87-
"Request to create vector of size %d, size must be positive", dimensions));
88-
}
89-
return new DefaultVectorType(subType, dimensions);
90-
}
68+
if (className.startsWith(DefaultVectorType.VECTOR_CLASS_NAME))
69+
return classNameParser.parse(className, AttachmentPoint.NONE);
9170
return new DefaultCustomType(className);
9271
}
9372

core/src/main/java/com/datastax/oss/driver/internal/core/metadata/schema/parsing/DataTypeClassNameParser.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import com.datastax.oss.protocol.internal.util.Bytes;
3535
import java.util.ArrayList;
3636
import java.util.Collections;
37+
import java.util.Iterator;
3738
import java.util.LinkedHashMap;
3839
import java.util.List;
3940
import java.util.Map;
@@ -164,6 +165,13 @@ private DataType parse(
164165
return new DefaultTupleType(componentTypesBuilder.build(), attachmentPoint);
165166
}
166167

168+
if (next.startsWith("org.apache.cassandra.db.marshal.VectorType")) {
169+
Iterator<String> rawTypes = parser.getTypeParameters().iterator();
170+
DataType subtype = parse(rawTypes.next(), userTypes, attachmentPoint, logPrefix);
171+
int dimensions = Integer.parseInt(rawTypes.next());
172+
return DataTypes.vectorOf(subtype, dimensions);
173+
}
174+
167175
DataType type = NATIVE_TYPES_BY_CLASS_NAME.get(next);
168176
return type == null ? DataTypes.custom(toParse) : type;
169177
}

core/src/main/java/com/datastax/oss/driver/internal/core/type/codec/VectorCodec.java

Lines changed: 46 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,11 @@ public ByteBuffer encode(
103103
if (valueBuff == null) {
104104
throw new NullPointerException("Vector elements cannot encode to CQL NULL");
105105
}
106+
int elementSize = valueBuff.limit();
106107
if (isVarSized) {
107-
int elementSize = valueBuff.limit();
108-
allValueBuffsSize += elementSize + VIntCoding.computeVIntSize(elementSize);
109-
} else {
110-
allValueBuffsSize += valueBuff.limit();
108+
allValueBuffsSize += VIntCoding.computeVIntSize(elementSize);
111109
}
110+
allValueBuffsSize += elementSize;
112111
valueBuff.rewind();
113112
valueBuffs[i] = valueBuff;
114113
}
@@ -139,56 +138,52 @@ public CqlVector<SubtypeT> decode(
139138
if (bytes == null || bytes.remaining() == 0) {
140139
return null;
141140
}
142-
boolean isVarSized = !subtypeCodec.serializedSize().isPresent();
143-
if (isVarSized) {
144-
ByteBuffer input = bytes.duplicate();
145-
List<SubtypeT> rv = new ArrayList<SubtypeT>(cqlType.getDimensions());
146-
for (int i = 0; i < cqlType.getDimensions(); ++i) {
147-
int size = VIntCoding.getUnsignedVInt32(input, input.position());
148-
input.position(input.position() + VIntCoding.computeUnsignedVIntSize(size));
149-
150-
ByteBuffer value;
151-
if (size < 0) {
152-
value = null;
153-
} else {
154-
value = input.duplicate();
155-
value.limit(value.position() + size);
156-
input.position(input.position() + size);
157-
}
158-
rv.add(subtypeCodec.decode(value, protocolVersion));
159-
}
160-
// if too many elements, throw
161-
if (input.hasRemaining()) {
162-
throw new IllegalArgumentException(
163-
String.format(
164-
"Too many elements; must provide elements for %d dimensions",
165-
cqlType.getDimensions()));
166-
}
167141

168-
return CqlVector.newInstance(rv);
169-
} else {
170-
int elementSize = subtypeCodec.serializedSize().get();
171-
if (bytes.remaining() != cqlType.getDimensions() * elementSize) {
172-
throw new IllegalArgumentException(
173-
String.format(
174-
"Expected elements of uniform size, observed %d elements with total bytes %d",
175-
cqlType.getDimensions(), bytes.remaining()));
176-
}
142+
// Upfront check for fixed-size types only
143+
subtypeCodec
144+
.serializedSize()
145+
.ifPresent(
146+
(fixed_size) -> {
147+
if (bytes.remaining() != cqlType.getDimensions() * fixed_size) {
148+
throw new IllegalArgumentException(
149+
String.format(
150+
"Expected elements of uniform size, observed %d elements with total bytes %d",
151+
cqlType.getDimensions(), bytes.remaining()));
152+
}
153+
});
154+
;
155+
ByteBuffer slice = bytes.slice();
156+
List<SubtypeT> rv = new ArrayList<SubtypeT>(cqlType.getDimensions());
157+
for (int i = 0; i < cqlType.getDimensions(); ++i) {
177158

178-
ByteBuffer slice = bytes.slice();
179-
List<SubtypeT> rv = new ArrayList<SubtypeT>(cqlType.getDimensions());
180-
for (int i = 0; i < cqlType.getDimensions(); ++i) {
181-
// Set the limit for the current element
182-
int originalPosition = slice.position();
183-
slice.limit(originalPosition + elementSize);
184-
rv.add(this.subtypeCodec.decode(slice, protocolVersion));
185-
// Move to the start of the next element
186-
slice.position(originalPosition + elementSize);
187-
// Reset the limit to the end of the buffer
188-
slice.limit(slice.capacity());
189-
}
190-
return CqlVector.newInstance(rv);
159+
int size =
160+
subtypeCodec
161+
.serializedSize()
162+
.orElseGet(() -> VIntCoding.getUnsignedVInt32(slice, slice.position()));
163+
// If we aren't dealing with a fixed-size type we need to move the current slice position
164+
// beyond the vint-encoded size of the current element. Ideally this would be
165+
// serializedSize().ifNotPresent(Consumer) but the Optional API isn't doing us any favors
166+
// there.
167+
if (!subtypeCodec.serializedSize().isPresent())
168+
slice.position(slice.position() + VIntCoding.computeUnsignedVIntSize(size));
169+
int originalPosition = slice.position();
170+
slice.limit(originalPosition + size);
171+
rv.add(this.subtypeCodec.decode(slice, protocolVersion));
172+
// Move to the start of the next element
173+
slice.position(originalPosition + size);
174+
// Reset the limit to the end of the buffer
175+
slice.limit(slice.capacity());
191176
}
177+
178+
// if too many elements, throw
179+
if (slice.hasRemaining()) {
180+
throw new IllegalArgumentException(
181+
String.format(
182+
"Too many elements; must provide elements for %d dimensions",
183+
cqlType.getDimensions()));
184+
}
185+
186+
return CqlVector.newInstance(rv);
192187
}
193188

194189
@NonNull

0 commit comments

Comments
 (0)