Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
206022a
added classes related to running hierarchical kmeans as a clustering …
john-wagster May 30, 2025
85e4d8f
[CI] Auto commit changes from spotless
May 30, 2025
6578e87
Merge branch 'main' into ivf_hkmeans
john-wagster May 30, 2025
4280682
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 2, 2025
58c5991
iter
john-wagster Jun 2, 2025
651efdf
[CI] Auto commit changes from spotless
Jun 2, 2025
5743d59
bringing back some interfaces
john-wagster Jun 2, 2025
47e5d8e
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 2, 2025
786e4f1
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 2, 2025
5ca53d3
accidentally remove suppressforbidden
john-wagster Jun 2, 2025
b1f9ae4
migrated from short to int and fixed IOUtils copy/paste errors
john-wagster Jun 2, 2025
075e2ce
no longer allocating larger arrays for slices that are the entire set…
john-wagster Jun 2, 2025
5fb98ff
[CI] Auto commit changes from spotless
Jun 2, 2025
523c2ca
iter on fvvs
john-wagster Jun 2, 2025
bb4531b
Merge branch 'ivf_hkmeans' of github.com:john-wagster/elasticsearch i…
john-wagster Jun 2, 2025
44b0aa9
iter on fvvs
john-wagster Jun 2, 2025
f5f0538
fixing comment
john-wagster Jun 2, 2025
3893098
switched to reservoir sampling
john-wagster Jun 3, 2025
1f2d053
switched to reservoir sampling
john-wagster Jun 3, 2025
6cda6a6
switched to reservoir sampling
john-wagster Jun 3, 2025
4cd94cf
missed a few short to int in tests
john-wagster Jun 3, 2025
b6d61fa
removed sorting on writeCentroids
john-wagster Jun 3, 2025
c82d719
migrated CentroidAssignments to a class to hide default constructor, …
john-wagster Jun 3, 2025
4bd2c9c
only getting the vector value on sampling when necessary
john-wagster Jun 3, 2025
1d61944
* stepLloyd now passes nextCentroids to prevent creating and rec…
john-wagster Jun 4, 2025
f05a541
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 4, 2025
26698d7
[CI] Auto commit changes from spotless
Jun 4, 2025
5112408
bug fixes around printing cluster metrics; still refactoring this
john-wagster Jun 4, 2025
dd61ba5
split kmeansresult into two classes, updated centroid assignments int…
john-wagster Jun 5, 2025
762839e
combined kmeans and kmeanslocal classes into one class, and fixed vis…
john-wagster Jun 5, 2025
e5746a1
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 5, 2025
44d0f24
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 6, 2025
e82af9c
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 6, 2025
cf7c6b3
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 6, 2025
2c96c82
added trimtosize and fixed a spot where we should be returning KMeans…
john-wagster Jun 6, 2025
cc5570a
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 6, 2025
3fec326
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 6, 2025
5dffeea
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 6, 2025
904f52d
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 7, 2025
aad4b3b
minor test fixes and edge cases
john-wagster Jun 8, 2025
1f93921
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 8, 2025
968f539
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 9, 2025
1048f7f
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 9, 2025
490946f
Merge remote-tracking branch 'upstream/main' into ivf_hkmeans
benwtrent Jun 9, 2025
12a1207
fixing bugs
benwtrent Jun 9, 2025
8ca12bc
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 9, 2025
69a0b4e
merge
john-wagster Jun 9, 2025
ab5a61c
removed unnecessary int[]
john-wagster Jun 10, 2025
93ca452
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 10, 2025
f935144
removed null checking for ffvslice for now because it's extra cruft; …
john-wagster Jun 10, 2025
fc41d7d
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 10, 2025
ff0fad4
making constructor private to reduce confusion
john-wagster Jun 10, 2025
b48bfcf
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 10, 2025
e05ac74
Merge branch 'main' into ivf_hkmeans
john-wagster Jun 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.codec.vectors;

record CentroidAssignments(int numCentroids, float[][] cachedCentroids, short[] assignments, short[] soarAssignments) {

CentroidAssignments(float[][] centroids, short[] assignments, short[] soarAssignments) {
this(centroids.length, centroids, assignments, soarAssignments);
}

CentroidAssignments(int numCentroids, short[] assignments, short[] soarAssignments) {
this(numCentroids, null, assignments, soarAssignments);
}
}

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.SuppressForbidden;

import java.io.IOException;
import java.io.UncheckedIOException;
Expand Down Expand Up @@ -122,38 +121,34 @@ public final KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOExc
return rawVectorDelegate;
}

protected abstract int calculateAndWriteCentroids(
abstract CentroidAssignments calculateAndWriteCentroids(
FieldInfo fieldInfo,
FloatVectorValues floatVectorValues,
IndexOutput temporaryCentroidOutput,
IndexOutput centroidOutput,
MergeState mergeState,
float[] globalCentroid
) throws IOException;

abstract long[] buildAndWritePostingsLists(
FieldInfo fieldInfo,
CentroidAssignmentScorer scorer,
FloatVectorValues floatVectorValues,
IndexOutput postingsOutput,
MergeState mergeState
) throws IOException;

abstract CentroidAssignmentScorer calculateAndWriteCentroids(
abstract CentroidAssignments calculateAndWriteCentroids(
FieldInfo fieldInfo,
FloatVectorValues floatVectorValues,
IndexOutput centroidOutput,
InfoStream infoStream,
float[] globalCentroid
) throws IOException;

abstract long[] buildAndWritePostingsLists(
FieldInfo fieldInfo,
InfoStream infoStream,
CentroidAssignmentScorer scorer,
CentroidSupplier centroidSupplier,
FloatVectorValues floatVectorValues,
IndexOutput postingsOutput
IndexOutput postingsOutput,
InfoStream infoStream,
CentroidAssignments centroidAssignments
) throws IOException;

abstract CentroidAssignmentScorer createCentroidScorer(
abstract CentroidSupplier createCentroidSupplier(float[][] cachedCentroids) throws IOException;

abstract CentroidSupplier createCentroidSupplier(
IndexInput centroidsInput,
int numCentroids,
FieldInfo fieldInfo,
Expand All @@ -165,33 +160,31 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
rawVectorDelegate.flush(maxDoc, sortMap);
for (FieldWriter fieldWriter : fieldWriters) {
float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()];
// calculate global centroid
for (var vector : fieldWriter.delegate.getVectors()) {
for (int i = 0; i < globalCentroid.length; i++) {
globalCentroid[i] += vector[i];
}
}
for (int i = 0; i < globalCentroid.length; i++) {
globalCentroid[i] /= fieldWriter.delegate.getVectors().size();
}
// build a float vector values with random access
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc);
// build centroids
long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
final CentroidAssignmentScorer centroidAssignmentScorer = calculateAndWriteCentroids(

final CentroidAssignments centroidAssignments = calculateAndWriteCentroids(
fieldWriter.fieldInfo,
floatVectorValues,
ivfCentroids,
segmentWriteState.infoStream,
globalCentroid
);

CentroidSupplier centroidSupplier = createCentroidSupplier(centroidAssignments.cachedCentroids());

long centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
final long[] offsets = buildAndWritePostingsLists(
fieldWriter.fieldInfo,
segmentWriteState.infoStream,
centroidAssignmentScorer,
centroidSupplier,
floatVectorValues,
ivfClusters
ivfClusters,
segmentWriteState.infoStream,
centroidAssignments
);
// write posting lists
writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid);
}
}
Expand Down Expand Up @@ -250,7 +243,6 @@ static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fiel
}

@Override
@SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)")
public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
Expand All @@ -276,26 +268,29 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro
float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors);
success = false;
CentroidAssignmentScorer centroidAssignmentScorer;
long centroidOffset;
long centroidLength;
String centroidTempName = null;
int numCentroids;
IndexOutput centroidTemp = null;
CentroidAssignments centroidAssignments;
try {
centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT);
centroidTempName = centroidTemp.getName();
numCentroids = calculateAndWriteCentroids(

centroidAssignments = calculateAndWriteCentroids(
fieldInfo,
floatVectorValues,
centroidTemp,
mergeState,
calculatedGlobalCentroid
);
numCentroids = centroidAssignments.numCentroids();

success = true;
} finally {
if (success == false && centroidTempName != null) {
IOUtils.closeWhileHandlingException(centroidTemp);
org.apache.lucene.util.IOUtils.closeWhileHandlingException(centroidTemp);
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName);
}
}
Expand All @@ -304,27 +299,34 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro
centroidOffset = ivfCentroids.getFilePointer();
writeMeta(fieldInfo, centroidOffset, 0, new long[0], null);
CodecUtil.writeFooter(centroidTemp);
IOUtils.close(centroidTemp);
org.apache.lucene.util.IOUtils.close(centroidTemp);
return;
}
CodecUtil.writeFooter(centroidTemp);
IOUtils.close(centroidTemp);
org.apache.lucene.util.IOUtils.close(centroidTemp);
centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
try (IndexInput centroidInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) {
ivfCentroids.copyBytes(centroidInput, centroidInput.length() - CodecUtil.footerLength());
try (IndexInput centroidsInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) {
ivfCentroids.copyBytes(centroidsInput, centroidsInput.length() - CodecUtil.footerLength());
centroidLength = ivfCentroids.getFilePointer() - centroidOffset;
centroidAssignmentScorer = createCentroidScorer(centroidInput, numCentroids, fieldInfo, calculatedGlobalCentroid);
assert centroidAssignmentScorer.size() == numCentroids;

CentroidSupplier centroidSupplier = createCentroidSupplier(
centroidsInput,
numCentroids,
fieldInfo,
calculatedGlobalCentroid
);

// build and write posting lists
final long[] offsets = buildAndWritePostingsLists(
fieldInfo,
centroidAssignmentScorer,
centroidSupplier,
floatVectorValues,
ivfClusters,
mergeState
mergeState.infoStream,
centroidAssignments
);
assert offsets.length == centroidAssignmentScorer.size();
assert offsets.length == centroidSupplier.size();
writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid);
}
} finally {
Expand Down Expand Up @@ -452,8 +454,8 @@ public final long ramBytesUsed() {

private record FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter<float[]> delegate) {}

interface CentroidAssignmentScorer {
CentroidAssignmentScorer EMPTY = new CentroidAssignmentScorer() {
interface CentroidSupplier {
CentroidSupplier EMPTY = new CentroidSupplier() {
@Override
public int size() {
return 0;
Expand All @@ -463,24 +465,10 @@ public int size() {
public float[] centroid(int centroidOrdinal) {
throw new IllegalStateException("No centroids");
}

@Override
public float score(int centroidOrdinal) {
throw new IllegalStateException("No centroids");
}

@Override
public void setScoringVector(float[] vector) {
throw new IllegalStateException("No centroids");
}
};

int size();

float[] centroid(int centroidOrdinal) throws IOException;

void setScoringVector(float[] vector);

float score(int centroidOrdinal) throws IOException;
}
}
Loading