Skip to content

Commit ebb36e0

Browse files
Fix for int8_hnsw when the number of vectors < MIN_NUM_VECTORS_FOR_GPU_BUILD after merge
1 parent cff3e94 commit ebb36e0

File tree

1 file changed

+14
-64
lines changed

1 file changed

+14
-64
lines changed

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/ESGpuHnswVectorsWriter.java

Lines changed: 14 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -242,18 +242,14 @@ private void writeFieldInternal(FieldInfo fieldInfo, DatasetOrVectors datasetOrV
242242
long vectorIndexOffset = vectorIndex.getFilePointer();
243243
int[][] graphLevelNodeOffsets = new int[1][];
244244
HnswGraph mockGraph;
245-
if (datasetOrVectors.vectors != null) {
246-
float[][] vectors = datasetOrVectors.vectors;
245+
if (datasetOrVectors.getVectors() != null) {
246+
int size = datasetOrVectors.size();
247247
if (logger.isDebugEnabled()) {
248-
logger.debug(
249-
"Skip building carga index; vectors length {} < {} (min for GPU)",
250-
vectors.length,
251-
MIN_NUM_VECTORS_FOR_GPU_BUILD
252-
);
248+
logger.debug("Skip building carga index; vectors length {} < {} (min for GPU)", size, MIN_NUM_VECTORS_FOR_GPU_BUILD);
253249
}
254-
mockGraph = writeGraph(vectors, graphLevelNodeOffsets);
250+
mockGraph = writeGraph(size, graphLevelNodeOffsets);
255251
} else {
256-
var dataset = datasetOrVectors.dataset;
252+
var dataset = datasetOrVectors.getDataset();
257253
var cuVSResources = cuVSResourceManager.acquire((int) dataset.size(), (int) dataset.columns());
258254
try {
259255
try (var index = buildGPUIndex(cuVSResources, fieldInfo.getVectorSimilarityFunction(), dataset)) {
@@ -340,13 +336,12 @@ private HnswGraph writeGraph(CuVSMatrix cagraGraph, int[][] levelNodeOffsets) th
340336
return createMockGraph(maxElementCount, maxGraphDegree);
341337
}
342338

343-
// create a graph where every node is connected to every other node
344-
private HnswGraph writeGraph(float[][] vectors, int[][] levelNodeOffsets) throws IOException {
345-
if (vectors.length == 0) {
339+
// create a mock graph where every node is connected to every other node
340+
private HnswGraph writeGraph(int elementCount, int[][] levelNodeOffsets) throws IOException {
341+
if (elementCount == 0) {
346342
return null;
347343
}
348-
int elementCount = vectors.length;
349-
int nodeDegree = vectors.length - 1;
344+
int nodeDegree = elementCount - 1;
350345
levelNodeOffsets[0] = new int[elementCount];
351346

352347
int[] neighbors = new int[nodeDegree];
@@ -447,9 +442,11 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE
447442
.fromInput(memorySegmentAccessInput, numVectors, fieldInfo.getVectorDimension(), dataType);
448443
datasetOrVectors = DatasetOrVectors.fromDataset(ds);
449444
} else {
450-
// TODO fix for byte vectors
451-
var fa = copyVectorsIntoArray(in, fieldInfo, numVectors);
452-
datasetOrVectors = DatasetOrVectors.fromArray(fa);
445+
assert numVectors < MIN_NUM_VECTORS_FOR_GPU_BUILD : "numVectors: " + numVectors;
446+
// we don't really need real value for vectors here,
447+
// we just build a mock graph where every node is connected to every other node
448+
float[][] vectors = new float[numVectors][fieldInfo.getVectorDimension()];
449+
datasetOrVectors = DatasetOrVectors.fromArray(vectors);
453450
}
454451
writeFieldInternal(fieldInfo, datasetOrVectors);
455452
} finally {
@@ -482,17 +479,6 @@ private static int writeByteVectorValues(IndexOutput out, ByteVectorValues vecto
482479
return numVectors;
483480
}
484481

485-
static float[][] copyVectorsIntoArray(IndexInput in, FieldInfo fieldInfo, int numVectors) throws IOException {
486-
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors);
487-
float[][] vectors = new float[numVectors][fieldInfo.getVectorDimension()];
488-
float[] vector;
489-
for (int i = 0; i < numVectors; i++) {
490-
vector = floatVectorValues.vectorValue(i);
491-
System.arraycopy(vector, 0, vectors[i], 0, vector.length);
492-
}
493-
return vectors;
494-
}
495-
496482
private static int writeFloatVectorValues(FieldInfo fieldInfo, IndexOutput out, FloatVectorValues floatVectorValues)
497483
throws IOException {
498484
int numVectors = 0;
@@ -507,42 +493,6 @@ private static int writeFloatVectorValues(FieldInfo fieldInfo, IndexOutput out,
507493
return numVectors;
508494
}
509495

510-
private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, IndexInput randomAccessInput, int numVectors) {
511-
if (numVectors == 0) {
512-
return FloatVectorValues.fromFloats(List.of(), fieldInfo.getVectorDimension());
513-
}
514-
final long length = (long) Float.BYTES * fieldInfo.getVectorDimension();
515-
final float[] vector = new float[fieldInfo.getVectorDimension()];
516-
return new FloatVectorValues() {
517-
@Override
518-
public float[] vectorValue(int ord) throws IOException {
519-
randomAccessInput.seek(ord * length);
520-
randomAccessInput.readFloats(vector, 0, vector.length);
521-
return vector;
522-
}
523-
524-
@Override
525-
public FloatVectorValues copy() {
526-
return this;
527-
}
528-
529-
@Override
530-
public int dimension() {
531-
return fieldInfo.getVectorDimension();
532-
}
533-
534-
@Override
535-
public int size() {
536-
return numVectors;
537-
}
538-
539-
@Override
540-
public int ordToDoc(int ord) {
541-
throw new UnsupportedOperationException("Not implemented");
542-
}
543-
};
544-
}
545-
546496
private void writeMeta(
547497
FieldInfo field,
548498
long vectorIndexOffset,

0 commit comments

Comments
 (0)