Skip to content

Commit 9f2c96f

Browse files
Fix bugs
1 parent e90fcd7 commit 9f2c96f

File tree

4 files changed

+12
-18
lines changed

4 files changed

+12
-18
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3018,7 +3018,9 @@ public KnnVectorsFormat getKnnVectorsFormatForField(KnnVectorsFormat defaultForm
30183018
List<KnnVectorsFormat> extraKnnFormats = new ArrayList<>();
30193019
for (VectorsFormatProvider vectorsFormatProvider : extraVectorsFormatProviders) {
30203020
KnnVectorsFormat extraKnnFormat = vectorsFormatProvider.getKnnVectorsFormat(indexSettings, indexOptions);
3021-
extraKnnFormats.add(extraKnnFormat);
3021+
if (extraKnnFormat != null) {
3022+
extraKnnFormats.add(extraKnnFormat);
3023+
}
30223024
}
30233025
if (extraKnnFormats.size() > 0) {
30243026
format = extraKnnFormats.get(0);

server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorsFormatProvider.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ public interface VectorsFormatProvider {
2020

2121
/**
2222
* Returns a {@link KnnVectorsFormat} instance based on the provided index settings and vector index options.
23+
* May return {@code null} if the provider does not support the format for the given index settings or vector index options.
2324
*
2425
* @param indexSettings The index settings.
2526
* @param indexOptions The dense vector index options.

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUToHNSWVectorsWriter.java

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
import org.elasticsearch.logging.Logger;
4444

4545
import java.io.IOException;
46-
import java.io.UncheckedIOException;
4746
import java.nio.ByteBuffer;
4847
import java.nio.ByteOrder;
4948
import java.util.ArrayList;
@@ -470,7 +469,6 @@ public NodesIterator getNodesOnLevel(int level) {
470469
@SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)")
471470
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
472471
flatVectorWriter.mergeOneField(fieldInfo, mergeState);
473-
FloatVectorValues vectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
474472
// save merged vector values to a temp file
475473
final int numVectors;
476474
String tempRawVectorsFileName = null;
@@ -487,9 +485,8 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
487485
}
488486
try (IndexInput in = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT)) {
489487
DatasetOrVectors datasetOrVectors;
490-
491488
var input = FilterIndexInput.unwrapOnlyTest(in);
492-
if (input instanceof MemorySegmentAccessInput memorySegmentAccessInput) {
489+
if (input instanceof MemorySegmentAccessInput memorySegmentAccessInput && numVectors >= MIN_NUM_VECTORS_FOR_GPU_BUILD) {
493490
var ds = DatasetUtils.getInstance().fromInput(memorySegmentAccessInput, numVectors, fieldInfo.getVectorDimension());
494491
datasetOrVectors = DatasetOrVectors.fromDataset(ds);
495492
} else {
@@ -521,7 +518,6 @@ private static int writeFloatVectorValues(FieldInfo fieldInfo, IndexOutput out,
521518
for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) {
522519
numVectors++;
523520
float[] vector = floatVectorValues.vectorValue(iterator.index());
524-
out.writeInt(iterator.docID());
525521
buffer.asFloatBuffer().put(vector);
526522
out.writeBytes(buffer.array(), buffer.array().length);
527523
}
@@ -532,12 +528,12 @@ private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, Index
532528
if (numVectors == 0) {
533529
return FloatVectorValues.fromFloats(List.of(), fieldInfo.getVectorDimension());
534530
}
535-
final long length = (long) Float.BYTES * fieldInfo.getVectorDimension() + Integer.BYTES;
531+
final long length = (long) Float.BYTES * fieldInfo.getVectorDimension();
536532
final float[] vector = new float[fieldInfo.getVectorDimension()];
537533
return new FloatVectorValues() {
538534
@Override
539535
public float[] vectorValue(int ord) throws IOException {
540-
randomAccessInput.seek(ord * length + Integer.BYTES);
536+
randomAccessInput.seek(ord * length);
541537
randomAccessInput.readFloats(vector, 0, vector.length);
542538
return vector;
543539
}
@@ -559,12 +555,7 @@ public int size() {
559555

560556
@Override
561557
public int ordToDoc(int ord) {
562-
try {
563-
randomAccessInput.seek(ord * length);
564-
return randomAccessInput.readInt();
565-
} catch (IOException e) {
566-
throw new UncheckedIOException(e);
567-
}
558+
throw new UnsupportedOperationException("Not implemented");
568559
}
569560
};
570561
}

x-pack/plugin/gpu/src/main21/java/org/elasticsearch/xpack/gpu/codec/DatasetUtilsImpl.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,23 +49,23 @@ public Dataset fromInput(MemorySegmentAccessInput input, int numVectors, int dim
4949
}
5050
MemorySegment ms = input.segmentSliceOrNull(0L, input.length());
5151
assert ms != null; // TODO: this can be null if larger than 16GB or ...
52-
if (((long) numVectors * dims * Float.BYTES) < ms.byteSize()) {
52+
if (((long) numVectors * dims * Float.BYTES) > ms.byteSize()) {
5353
throwIllegalArgumentException(ms, numVectors, dims);
5454
}
5555
return fromMemorySegment(ms, numVectors, dims);
5656
}
5757

5858
static void throwIllegalArgumentException(MemorySegment ms, int numVectors, int dims) {
59-
var s = "segment of size [" + ms.byteSize() + "] too small for expected " + numVectors + " float vectors of " + dims + "dimensions";
59+
var s = "segment of size [" + ms.byteSize() + "] too small for expected " + numVectors + " float vectors of " + dims + " dims";
6060
throw new IllegalArgumentException(s);
6161
}
6262

6363
static void throwIllegalArgumentException(int numVectors, int dims) {
6464
String s;
6565
if (numVectors < 0) {
66-
s = "negative number of vectors:" + numVectors;
66+
s = "negative number of vectors: " + numVectors;
6767
} else {
68-
s = "negative vector dims:" + dims;
68+
s = "negative vector dims: " + dims;
6969
}
7070
throw new IllegalArgumentException(s);
7171
}

0 commit comments

Comments
 (0)