Skip to content

Commit 26ffd7f

Browse files
authored
ESQL - Add byte element support for dense_vector data type (elastic#131863)
1 parent 4e3602d commit 26ffd7f

File tree

11 files changed

+452
-178
lines changed

11 files changed

+452
-178
lines changed

server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java

Lines changed: 142 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.index.mapper;
1111

1212
import org.apache.lucene.index.BinaryDocValues;
13+
import org.apache.lucene.index.ByteVectorValues;
1314
import org.apache.lucene.index.DocValues;
1415
import org.apache.lucene.index.FloatVectorValues;
1516
import org.apache.lucene.index.KnnVectorValues;
@@ -30,11 +31,15 @@
3031
import org.elasticsearch.index.mapper.BlockLoader.DoubleBuilder;
3132
import org.elasticsearch.index.mapper.BlockLoader.IntBuilder;
3233
import org.elasticsearch.index.mapper.BlockLoader.LongBuilder;
34+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
35+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
3336
import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder;
3437
import org.elasticsearch.search.fetch.StoredFieldsSpec;
3538

3639
import java.io.IOException;
3740

41+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType.BYTE;
42+
3843
/**
3944
* A reader that supports reading doc-values from a Lucene segment in Block fashion.
4045
*/
@@ -516,10 +521,12 @@ public String toString() {
516521
public static class DenseVectorBlockLoader extends DocValuesBlockLoader {
517522
private final String fieldName;
518523
private final int dimensions;
524+
private final DenseVectorFieldMapper.DenseVectorFieldType fieldType;
519525

520-
/**
 * Creates a block loader for a dense vector field.
 *
 * @param fieldName  name of the field whose vector values are looked up per segment
 * @param dimensions number of elements expected in every vector of the field
 * @param fieldType  field type; its element type selects which vector values are read
 */
public DenseVectorBlockLoader(String fieldName, int dimensions, DenseVectorFieldMapper.DenseVectorFieldType fieldType) {
    this.fieldType = fieldType;
    this.dimensions = dimensions;
    this.fieldName = fieldName;
}
524531

525532
@Override
@@ -529,22 +536,34 @@ public Builder builder(BlockFactory factory, int expectedCount) {
529536

530537
@Override
531538
public AllReader reader(LeafReaderContext context) throws IOException {
532-
FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName);
533-
if (floatVectorValues != null) {
534-
return new DenseVectorValuesBlockReader(floatVectorValues, dimensions);
539+
switch (fieldType.getElementType()) {
540+
case FLOAT -> {
541+
FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName);
542+
if (floatVectorValues != null) {
543+
return new FloatDenseVectorValuesBlockReader(floatVectorValues, dimensions);
544+
}
545+
}
546+
case BYTE -> {
547+
ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(fieldName);
548+
if (byteVectorValues != null) {
549+
return new ByteDenseVectorValuesBlockReader(byteVectorValues, dimensions);
550+
}
551+
}
535552
}
553+
536554
return new ConstantNullsReader();
537555
}
538556
}
539557

540-
private static class DenseVectorValuesBlockReader extends BlockDocValuesReader {
541-
private final FloatVectorValues floatVectorValues;
542-
private final KnnVectorValues.DocIndexIterator iterator;
543-
private final int dimensions;
558+
private abstract static class DenseVectorValuesBlockReader<T extends KnnVectorValues> extends BlockDocValuesReader {
559+
560+
protected final T vectorValues;
561+
protected final KnnVectorValues.DocIndexIterator iterator;
562+
protected final int dimensions;
544563

545-
/**
 * Wraps the given vector values and obtains the iterator used to position on documents.
 *
 * @param vectorValues the segment's vector values for the field
 * @param dimensions   number of elements expected in every vector
 */
DenseVectorValuesBlockReader(T vectorValues, int dimensions) {
    this.dimensions = dimensions;
    this.vectorValues = vectorValues;
    this.iterator = vectorValues.iterator();
}
550569

@@ -569,26 +588,59 @@ private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException
569588
builder.appendNull();
570589
} else if (iterator.docID() == doc || iterator.advance(doc) == doc) {
571590
builder.beginPositionEntry();
572-
float[] floats = floatVectorValues.vectorValue(iterator.index());
573-
assert floats.length == dimensions
574-
: "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length;
575-
for (float aFloat : floats) {
576-
builder.appendFloat(aFloat);
577-
}
591+
appendDoc(builder);
578592
builder.endPositionEntry();
579593
} else {
580594
builder.appendNull();
581595
}
582596
}
583597

598+
protected abstract void appendDoc(BlockLoader.FloatBuilder builder) throws IOException;
599+
584600
@Override
585601
public int docId() {
586602
return iterator.docID();
587603
}
604+
}
605+
606+
private static class FloatDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader<FloatVectorValues> {
607+
608+
FloatDenseVectorValuesBlockReader(FloatVectorValues floatVectorValues, int dimensions) {
609+
super(floatVectorValues, dimensions);
610+
}
611+
612+
protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException {
613+
float[] floats = vectorValues.vectorValue(iterator.index());
614+
assert floats.length == dimensions
615+
: "unexpected dimensions for vector value; expected " + dimensions + " but got " + floats.length;
616+
for (float aFloat : floats) {
617+
builder.appendFloat(aFloat);
618+
}
619+
}
588620

589621
@Override
590622
public String toString() {
591-
return "BlockDocValuesReader.FloatVectorValuesBlockReader";
623+
return "BlockDocValuesReader.FloatDenseVectorValuesBlockReader";
624+
}
625+
}
626+
627+
private static class ByteDenseVectorValuesBlockReader extends DenseVectorValuesBlockReader<ByteVectorValues> {
628+
ByteDenseVectorValuesBlockReader(ByteVectorValues floatVectorValues, int dimensions) {
629+
super(floatVectorValues, dimensions);
630+
}
631+
632+
protected void appendDoc(BlockLoader.FloatBuilder builder) throws IOException {
633+
byte[] bytes = vectorValues.vectorValue(iterator.index());
634+
assert bytes.length == dimensions
635+
: "unexpected dimensions for vector value; expected " + dimensions + " but got " + bytes.length;
636+
for (byte aFloat : bytes) {
637+
builder.appendFloat(aFloat);
638+
}
639+
}
640+
641+
@Override
642+
public String toString() {
643+
return "BlockDocValuesReader.ByteDenseVectorValuesBlockReader";
592644
}
593645
}
594646

@@ -880,11 +932,13 @@ public static class DenseVectorFromBinaryBlockLoader extends DocValuesBlockLoade
880932
private final String fieldName;
881933
private final int dims;
882934
private final IndexVersion indexVersion;
935+
private final ElementType elementType;
883936

884-
/**
 * Creates a block loader for dense vectors stored as binary doc values.
 *
 * @param fieldName    name of the field to load
 * @param dims         number of elements in every vector
 * @param indexVersion index version handed to the vector decoder
 * @param elementType  element type that selects the float or byte reader
 */
public DenseVectorFromBinaryBlockLoader(String fieldName, int dims, IndexVersion indexVersion, ElementType elementType) {
    this.elementType = elementType;
    this.indexVersion = indexVersion;
    this.dims = dims;
    this.fieldName = fieldName;
}
889943

890944
@Override
@@ -898,23 +952,40 @@ public AllReader reader(LeafReaderContext context) throws IOException {
898952
if (docValues == null) {
899953
return new ConstantNullsReader();
900954
}
901-
return new DenseVectorFromBinary(docValues, dims, indexVersion);
955+
switch (elementType) {
956+
case FLOAT:
957+
return new FloatDenseVectorFromBinary(docValues, dims, indexVersion);
958+
case BYTE:
959+
return new ByteDenseVectorFromBinary(docValues, dims, indexVersion);
960+
default:
961+
throw new IllegalArgumentException("Unknown element type [" + elementType + "]");
962+
}
902963
}
903964
}
904965

905-
private static class DenseVectorFromBinary extends BlockDocValuesReader {
906-
private final BinaryDocValues docValues;
907-
private final IndexVersion indexVersion;
908-
private final int dimensions;
909-
private final float[] scratch;
910-
911-
private int docID = -1;
966+
// Abstract base for dense vector readers
967+
private abstract static class AbstractDenseVectorFromBinary<T> extends BlockDocValuesReader {
968+
protected final BinaryDocValues docValues;
969+
protected final IndexVersion indexVersion;
970+
protected final int dimensions;
971+
protected final T scratch;
972+
protected int docID = -1;
912973

913-
/**
 * Stores the doc values to read from and the decoding state.
 *
 * @param docValues    binary doc values holding the encoded vectors
 * @param dims         number of elements in every vector
 * @param indexVersion index version kept for use by the decoding subclasses
 * @param scratch      buffer supplied by the subclass; it is handed to the decode and write hooks
 */
AbstractDenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion, T scratch) {
    this.scratch = scratch;
    this.dimensions = dims;
    this.indexVersion = indexVersion;
    this.docValues = docValues;
}
980+
981+
@Override
982+
public int docId() {
983+
return docID;
984+
}
985+
986+
@Override
987+
public void read(int docId, BlockLoader.StoredFields storedFields, Builder builder) throws IOException {
988+
read(docId, (BlockLoader.FloatBuilder) builder);
918989
}
919990

920991
@Override
@@ -931,36 +1002,67 @@ public BlockLoader.Block read(BlockFactory factory, Docs docs, int offset) throw
9311002
}
9321003
}
9331004

934-
@Override
935-
public void read(int docId, BlockLoader.StoredFields storedFields, Builder builder) throws IOException {
936-
read(docId, (BlockLoader.FloatBuilder) builder);
937-
}
938-
9391005
private void read(int doc, BlockLoader.FloatBuilder builder) throws IOException {
9401006
this.docID = doc;
941-
if (false == docValues.advanceExact(doc)) {
1007+
if (docValues.advanceExact(doc) == false) {
9421008
builder.appendNull();
9431009
return;
9441010
}
9451011
BytesRef bytesRef = docValues.binaryValue();
9461012
assert bytesRef.length > 0;
947-
VectorEncoderDecoder.decodeDenseVector(indexVersion, bytesRef, scratch);
1013+
decodeDenseVector(bytesRef, scratch);
9481014

9491015
builder.beginPositionEntry();
1016+
writeScratchToBuilder(scratch, builder);
1017+
builder.endPositionEntry();
1018+
}
1019+
1020+
protected abstract void decodeDenseVector(BytesRef bytesRef, T scratch);
1021+
1022+
protected abstract void writeScratchToBuilder(T scratch, BlockLoader.FloatBuilder builder);
1023+
}
1024+
1025+
private static class FloatDenseVectorFromBinary extends AbstractDenseVectorFromBinary<float[]> {
1026+
FloatDenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) {
1027+
super(docValues, dims, indexVersion, new float[dims]);
1028+
}
1029+
1030+
@Override
1031+
protected void writeScratchToBuilder(float[] scratch, BlockLoader.FloatBuilder builder) {
9501032
for (float value : scratch) {
9511033
builder.appendFloat(value);
9521034
}
953-
builder.endPositionEntry();
9541035
}
9551036

9561037
@Override
957-
public int docId() {
958-
return docID;
1038+
protected void decodeDenseVector(BytesRef bytesRef, float[] scratch) {
1039+
VectorEncoderDecoder.decodeDenseVector(indexVersion, bytesRef, scratch);
9591040
}
9601041

9611042
@Override
9621043
public String toString() {
963-
return "DenseVectorFromBinary.Bytes";
1044+
return "FloatDenseVectorFromBinary.Bytes";
1045+
}
1046+
}
1047+
1048+
private static class ByteDenseVectorFromBinary extends AbstractDenseVectorFromBinary<byte[]> {
1049+
ByteDenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) {
1050+
super(docValues, dims, indexVersion, new byte[dims]);
1051+
}
1052+
1053+
@Override
1054+
public String toString() {
1055+
return "ByteDenseVectorFromBinary.Bytes";
1056+
}
1057+
1058+
protected void writeScratchToBuilder(byte[] scratch, BlockLoader.FloatBuilder builder) {
1059+
for (byte value : scratch) {
1060+
builder.appendFloat(value);
1061+
}
1062+
}
1063+
1064+
protected void decodeDenseVector(BytesRef bytesRef, byte[] scratch) {
1065+
VectorEncoderDecoder.decodeDenseVector(indexVersion, bytesRef, scratch);
9641066
}
9651067
}
9661068

0 commit comments

Comments
 (0)