Skip to content

Commit 30df77a

Browse files
committed
Add reading support for byte element type
1 parent 8194f3c commit 30df77a

File tree

4 files changed

+244
-19
lines changed

4 files changed

+244
-19
lines changed

server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java

Lines changed: 169 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.index.mapper;
1111

1212
import org.apache.lucene.index.BinaryDocValues;
13+
import org.apache.lucene.index.ByteVectorValues;
1314
import org.apache.lucene.index.DocValues;
1415
import org.apache.lucene.index.FloatVectorValues;
1516
import org.apache.lucene.index.KnnVectorValues;
@@ -508,10 +509,10 @@ public String toString() {
508509
}
509510
}
510511

511-
public static class DenseVectorBlockLoader extends DocValuesBlockLoader {
512+
public static class DenseVectorFloatBlockLoader extends DocValuesBlockLoader {
512513
private final String fieldName;
513514

514-
public DenseVectorBlockLoader(String fieldName) {
515+
public DenseVectorFloatBlockLoader(String fieldName) {
515516
this.fieldName = fieldName;
516517
}
517518

@@ -583,6 +584,83 @@ public String toString() {
583584
}
584585
}
585586

587+
public static class DenseVectorByteBlockLoader extends DocValuesBlockLoader {
588+
private final String fieldName;
589+
590+
public DenseVectorByteBlockLoader(String fieldName) {
591+
this.fieldName = fieldName;
592+
}
593+
594+
@Override
595+
public Builder builder(BlockFactory factory, int expectedCount) {
596+
return factory.bytesRefs(expectedCount);
597+
}
598+
599+
@Override
600+
public AllReader reader(LeafReaderContext context) throws IOException {
601+
ByteVectorValues byteVectorValues = context.reader().getByteVectorValues(fieldName);
602+
if (byteVectorValues != null) {
603+
return new ByteVectorValuesBlockReader(byteVectorValues);
604+
}
605+
return new ConstantNullsReader();
606+
}
607+
}
608+
609+
private static class ByteVectorValuesBlockReader extends BlockDocValuesReader {
610+
private final ByteVectorValues byteVectorValues;
611+
private final KnnVectorValues.DocIndexIterator iterator;
612+
613+
ByteVectorValuesBlockReader(ByteVectorValues byteVectorValues) {
614+
this.byteVectorValues = byteVectorValues;
615+
iterator = byteVectorValues.iterator();
616+
}
617+
618+
@Override
619+
public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException {
620+
// Doubles from doc values ensures that the values are in order
621+
try (BlockLoader.BytesRefBuilder builder = factory.bytesRefsFromDocValues(docs.count())) {
622+
for (int i = 0; i < docs.count(); i++) {
623+
int doc = docs.get(i);
624+
if (doc < iterator.docID()) {
625+
throw new IllegalStateException("docs within same block must be in order");
626+
}
627+
read(doc, builder);
628+
}
629+
return builder.build();
630+
}
631+
}
632+
633+
@Override
634+
public void read(int docId, BlockLoader.StoredFields storedFields, Builder builder) throws IOException {
635+
read(docId, (BlockLoader.BytesRefBuilder) builder);
636+
}
637+
638+
private void read(int doc, BlockLoader.BytesRefBuilder builder) throws IOException {
639+
if (iterator.advance(doc) == doc) {
640+
builder.beginPositionEntry();
641+
byte[] bytes = byteVectorValues.vectorValue(iterator.index());
642+
BytesRef scratch = new BytesRef(bytes, 0, 1);
643+
for (int i = 0; i < bytes.length; i++) {
644+
scratch.offset = i;
645+
builder.appendBytesRef(scratch);
646+
}
647+
builder.endPositionEntry();
648+
} else {
649+
builder.appendNull();
650+
}
651+
}
652+
653+
@Override
654+
public int docId() {
655+
return iterator.docID();
656+
}
657+
658+
@Override
659+
public String toString() {
660+
return "BlockDocValuesReader.FloatVectorValuesBlockReader";
661+
}
662+
}
663+
586664
public static class BytesRefsFromOrdsBlockLoader extends DocValuesBlockLoader {
587665
private final String fieldName;
588666

@@ -831,12 +909,12 @@ public String toString() {
831909
}
832910
}
833911

834-
public static class DenseVectorFromBinaryBlockLoader extends DocValuesBlockLoader {
912+
public static class DenseVectorFromFloatsBinaryBlockLoader extends DocValuesBlockLoader {
835913
private final String fieldName;
836914
private final int dims;
837915
private final IndexVersion indexVersion;
838916

839-
public DenseVectorFromBinaryBlockLoader(String fieldName, int dims, IndexVersion indexVersion) {
917+
public DenseVectorFromFloatsBinaryBlockLoader(String fieldName, int dims, IndexVersion indexVersion) {
840918
this.fieldName = fieldName;
841919
this.dims = dims;
842920
this.indexVersion = indexVersion;
@@ -853,18 +931,18 @@ public AllReader reader(LeafReaderContext context) throws IOException {
853931
if (docValues == null) {
854932
return new ConstantNullsReader();
855933
}
856-
return new DenseVectorFromBinary(docValues, dims, indexVersion);
934+
return new DenseVectorFromFloatsBinary(docValues, dims, indexVersion);
857935
}
858936
}
859937

860-
private static class DenseVectorFromBinary extends BlockDocValuesReader {
938+
private static class DenseVectorFromFloatsBinary extends BlockDocValuesReader {
861939
private final BinaryDocValues docValues;
862940
private final IndexVersion indexVersion;
863941
private final float[] scratch;
864942

865943
private int docID = -1;
866944

867-
DenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) {
945+
DenseVectorFromFloatsBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) {
868946
this.docValues = docValues;
869947
this.scratch = new float[dims];
870948
this.indexVersion = indexVersion;
@@ -913,7 +991,90 @@ public int docId() {
913991

914992
@Override
915993
public String toString() {
916-
return "DenseVectorFromBinary.Bytes";
994+
return "DenseVectorFromFloatsBinary.Bytes";
995+
}
996+
}
997+
998+
public static class DenseVectorFromFBytesBinaryBlockLoader extends DocValuesBlockLoader {
999+
private final String fieldName;
1000+
private final int dims;
1001+
1002+
public DenseVectorFromFBytesBinaryBlockLoader(String fieldName, int dims) {
1003+
this.fieldName = fieldName;
1004+
this.dims = dims;
1005+
}
1006+
1007+
@Override
1008+
public Builder builder(BlockFactory factory, int expectedCount) {
1009+
return factory.bytesRefs(expectedCount);
1010+
}
1011+
1012+
@Override
1013+
public AllReader reader(LeafReaderContext context) throws IOException {
1014+
BinaryDocValues docValues = context.reader().getBinaryDocValues(fieldName);
1015+
if (docValues == null) {
1016+
return new ConstantNullsReader();
1017+
}
1018+
return new DenseVectorFromBytesBinary(docValues, dims);
1019+
}
1020+
}
1021+
1022+
private static class DenseVectorFromBytesBinary extends BlockDocValuesReader {
1023+
private final BinaryDocValues docValues;
1024+
private final int dims;
1025+
1026+
private int docID = -1;
1027+
1028+
DenseVectorFromBytesBinary(BinaryDocValues docValues, int dims) {
1029+
this.docValues = docValues;
1030+
this.dims = dims;
1031+
}
1032+
1033+
@Override
1034+
public BlockLoader.Block read(BlockFactory factory, Docs docs) throws IOException {
1035+
try (BlockLoader.BytesRefBuilder builder = factory.bytesRefs(docs.count())) {
1036+
for (int i = 0; i < docs.count(); i++) {
1037+
int doc = docs.get(i);
1038+
if (doc < docID) {
1039+
throw new IllegalStateException("docs within same block must be in order");
1040+
}
1041+
read(doc, builder);
1042+
}
1043+
return builder.build();
1044+
}
1045+
}
1046+
1047+
@Override
1048+
public void read(int docId, BlockLoader.StoredFields storedFields, Builder builder) throws IOException {
1049+
read(docId, (BlockLoader.BytesRefBuilder) builder);
1050+
}
1051+
1052+
private void read(int doc, BlockLoader.BytesRefBuilder builder) throws IOException {
1053+
this.docID = doc;
1054+
if (false == docValues.advanceExact(doc)) {
1055+
builder.appendNull();
1056+
return;
1057+
}
1058+
BytesRef bytesRef = docValues.binaryValue();
1059+
assert bytesRef.length > 0;
1060+
1061+
builder.beginPositionEntry();
1062+
BytesRef scratch = new BytesRef(bytesRef.bytes, 0, 1);
1063+
for (int i = 0; i < dims; i++) {
1064+
scratch.offset = i;
1065+
builder.appendBytesRef(scratch);
1066+
}
1067+
builder.endPositionEntry();
1068+
}
1069+
1070+
@Override
1071+
public int docId() {
1072+
return docID;
1073+
}
1074+
1075+
@Override
1076+
public String toString() {
1077+
return "DenseVectorFromBytesBinary.Bytes";
9171078
}
9181079
}
9191080

server/src/main/java/org/elasticsearch/index/mapper/BlockSourceReader.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,45 @@ public String toString() {
341341
}
342342
}
343343

344+
public static class BytesBlockLoader extends SourceBlockLoader {
345+
public BytesBlockLoader(ValueFetcher fetcher, LeafIteratorLookup lookup) {
346+
super(fetcher, lookup);
347+
}
348+
349+
@Override
350+
public Builder builder(BlockFactory factory, int expectedCount) {
351+
return factory.bytesRefs(expectedCount);
352+
}
353+
354+
@Override
355+
public RowStrideReader rowStrideReader(LeafReaderContext context, DocIdSetIterator iter) {
356+
return new Bytes(fetcher, iter);
357+
}
358+
359+
@Override
360+
protected String name() {
361+
return "Bytes";
362+
}
363+
}
364+
365+
private static class Bytes extends BlockSourceReader {
366+
private BytesRef scratch = new BytesRef(1);
367+
368+
Bytes(ValueFetcher fetcher, DocIdSetIterator iter) {
369+
super(fetcher, iter);
370+
}
371+
372+
@Override
373+
protected void append(BlockLoader.Builder builder, Object v) {
374+
((BlockLoader.BytesRefBuilder) builder).appendBytesRef(new BytesRef(((Number) v).byteValue()));
375+
}
376+
377+
@Override
378+
public String toString() {
379+
return "BlockSourceReader.Bytes";
380+
}
381+
}
382+
344383
/**
345384
* Load {@code int}s from {@code _source}.
346385
*/

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -434,17 +434,30 @@ public double computeSquaredMagnitude(VectorData vectorData) {
434434

435435
@Override
436436
public DocValuesBlockLoader indexedBlockLoader(String fieldName) {
437-
return null;
437+
return new BlockDocValuesReader.DenseVectorByteBlockLoader(fieldName);
438438
}
439439

440440
@Override
441441
public DocValuesBlockLoader docValuesBlockLoader(String fieldName, int dimensions, IndexVersion indexVersionCreated) {
442-
return null;
442+
return new BlockDocValuesReader.DenseVectorFromFBytesBinaryBlockLoader(fieldName, dimensions);
443443
}
444444

445445
@Override
446446
public BlockLoader sourceBlockLoader(String fieldName, MappedFieldType.BlockLoaderContext blContext) {
447-
return null;
447+
BlockSourceReader.LeafIteratorLookup lookup = BlockSourceReader.lookupMatchingAll();
448+
return new BlockSourceReader.FloatsBlockLoader(sourceValueFetcher(blContext.sourcePaths(fieldName)), lookup);
449+
}
450+
451+
private SourceValueFetcher sourceValueFetcher(Set<String> sourcePaths) {
452+
return new SourceValueFetcher(sourcePaths, null) {
453+
@Override
454+
protected Object parseSourceValue(Object value) {
455+
if (value.equals("")) {
456+
return null;
457+
}
458+
return NumberFieldMapper.NumberType.BYTE.parse(value, false);
459+
}
460+
};
448461
}
449462

450463
private VectorData parseVectorArray(
@@ -690,12 +703,12 @@ public double computeSquaredMagnitude(VectorData vectorData) {
690703

691704
@Override
692705
public DocValuesBlockLoader indexedBlockLoader(String fieldName) {
693-
return new BlockDocValuesReader.DenseVectorBlockLoader(fieldName);
706+
return new BlockDocValuesReader.DenseVectorFloatBlockLoader(fieldName);
694707
}
695708

696709
@Override
697710
public DocValuesBlockLoader docValuesBlockLoader(String fieldName, int dimensions, IndexVersion indexVersionCreated) {
698-
return new BlockDocValuesReader.DenseVectorFromBinaryBlockLoader(fieldName, dimensions, indexVersionCreated);
711+
return new BlockDocValuesReader.DenseVectorFromFloatsBinaryBlockLoader(fieldName, dimensions, indexVersionCreated);
699712
}
700713

701714
@Override

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@
6666

6767
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
6868
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
69-
import static org.hamcrest.Matchers.closeTo;
7069
import static org.hamcrest.Matchers.containsString;
7170
import static org.hamcrest.Matchers.equalTo;
7271
import static org.hamcrest.Matchers.instanceOf;
@@ -2239,7 +2238,7 @@ protected IngestScriptSupport ingestScriptSupport() {
22392238

22402239
@Override
22412240
protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed) {
2242-
return new DenseVectorSyntheticSourceSupport();
2241+
return new DenseVectorSyntheticSourceSupport(elementType, dims);
22432242
}
22442243

22452244
@Override
@@ -2255,7 +2254,7 @@ protected BlockReaderSupport getSupportedReaders(MapperService mapper, String lo
22552254
@Override
22562255
protected Function<Object, Object> loadBlockExpected(BlockReaderSupport blockReaderSupport, boolean columnReader) {
22572256
DenseVectorFieldType ft = (DenseVectorFieldType) blockReaderSupport.mapper().fieldType(blockReaderSupport.loaderFieldName());
2258-
if (ft.getElementType() != ElementType.FLOAT) {
2257+
if (ft.getElementType() == ElementType.BIT) {
22592258
return null;
22602259
}
22612260

@@ -2264,15 +2263,28 @@ protected Function<Object, Object> loadBlockExpected(BlockReaderSupport blockRea
22642263

22652264
@Override
22662265
protected Matcher<?> blockItemMatcher(Object expected) {
2267-
return equalTo(((Double) expected).floatValue());
2266+
Number expectedNumber = (Number) expected;
2267+
switch (elementType) {
2268+
case ElementType.FLOAT:
2269+
return equalTo(expectedNumber.floatValue());
2270+
case ElementType.BYTE, ElementType.BIT:
2271+
return equalTo(new BytesRef(new byte[] { expectedNumber.byteValue() }));
2272+
}
2273+
fail("Unexpected element type: " + elementType);
2274+
return null;
22682275
}
22692276

22702277
private static class DenseVectorSyntheticSourceSupport implements SyntheticSourceSupport {
2271-
private final int dims = between(5, 1000);
2272-
private final ElementType elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT);
2278+
private final int dims;
2279+
private final ElementType elementType;
22732280
private final boolean indexed = randomBoolean();
22742281
private final boolean indexOptionsSet = indexed && randomBoolean();
22752282

2283+
DenseVectorSyntheticSourceSupport(ElementType elementType, int dims) {
2284+
this.elementType = elementType;
2285+
this.dims = dims;
2286+
}
2287+
22762288
@Override
22772289
public SyntheticSourceExample example(int maxValues) throws IOException {
22782290
Object value = switch (elementType) {

0 commit comments

Comments
 (0)