Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
8727f3a
wip on sorting centroids and assignments
john-wagster Jun 12, 2025
5741a0d
wip on sorting centroids and assignments
john-wagster Jun 13, 2025
0cc1f67
Merge branch 'main' into ivf_hkmeans_struc2
john-wagster Jun 17, 2025
188446a
got everything working e-to-e w single file centroids, still needs se…
john-wagster Jun 21, 2025
ae53b11
merging
john-wagster Jun 23, 2025
0dab43d
merging w new bulk score api
john-wagster Jun 23, 2025
fc73af3
bug fixes and cleanup
john-wagster Jun 24, 2025
470245b
cleanup
john-wagster Jun 24, 2025
8d25046
iter
john-wagster Jun 24, 2025
3e88c51
iter
john-wagster Jun 24, 2025
b089f8d
Merge branch 'main' into ivf_hkmeans_struc2
john-wagster Jun 24, 2025
6940131
Merge branch 'main' into ivf_hkmeans_struc2
john-wagster Jun 24, 2025
839f0e9
Merge branch 'ivf_hkmeans_struc2' of github.com:john-wagster/elastics…
john-wagster Jun 24, 2025
fff6650
Merge branch 'main' into ivf_hkmeans_struc2
john-wagster Jun 25, 2025
a718197
added exploration for a percentage of parents
john-wagster Jun 26, 2025
c333dcf
Merge branch 'ivf_hkmeans_struc2' of github.com:john-wagster/elastics…
john-wagster Jun 26, 2025
ce16153
Merge branch 'main' into ivf_hkmeans_struc2
john-wagster Jun 26, 2025
13a0007
[CI] Auto commit changes from spotless
Jun 26, 2025
9c3c0fb
Merge branch 'main' into ivf_hkmeans_struc2
john-wagster Jun 26, 2025
618d1ab
minor cleanup
john-wagster Jun 26, 2025
33b778f
Merge branch 'ivf_hkmeans_struc2' of github.com:john-wagster/elastics…
john-wagster Jun 26, 2025
325a21e
Merge branch 'main' into ivf_hkmeans_struc2
john-wagster Jun 26, 2025
42d2135
fix for small data usecase
john-wagster Jun 26, 2025
2e04623
Merge branch 'main' into ivf_hkmeans_struc2
john-wagster Jun 26, 2025
2f5bfd3
merge
john-wagster Jun 27, 2025
5910259
iterated on better mechanism for utilizing parent centroids
john-wagster Jun 29, 2025
ef9402c
scores not distances + added some diagnostics to be removed subsequently
john-wagster Jun 30, 2025
7d5bb39
merge
john-wagster Jun 30, 2025
77f020c
[CI] Auto commit changes from spotless
Jun 30, 2025
052f431
merge
john-wagster Jul 1, 2025
d7f10ea
Merge branch 'ivf_hkmeans_struc2' of github.com:john-wagster/elastics…
john-wagster Jul 1, 2025
89e5699
Merge branch 'main' into ivf_hkmeans_struc2
john-wagster Jul 1, 2025
ddaf610
merge
john-wagster Jul 2, 2025
ed48801
Merge branch 'main' into ivf_hkmeans_struc2
john-wagster Jul 3, 2025
ffe2929
Merge branch 'main' into ivf_hkmeans_struc2
john-wagster Jul 5, 2025
3a2ba61
using full hkmeans to gen parents
john-wagster Jul 7, 2025
7277b0a
[CI] Auto commit changes from spotless
Jul 7, 2025
60ef44d
merging
john-wagster Jul 7, 2025
866ac5a
Merge branch 'ivf_hkmeans_struc2' of github.com:john-wagster/elastics…
john-wagster Jul 7, 2025
402c767
cleanup
john-wagster Jul 7, 2025
0260eac
assert fix and minor cleanup
john-wagster Jul 7, 2025
3b0764e
[CI] Auto commit changes from spotless
Jul 7, 2025
5aa2682
fixed 1 off error and other cleanup
john-wagster Jul 8, 2025
1b4f69c
introducing soar logic at the parent centroid level
john-wagster Jul 11, 2025
234a0a8
bad assert and hardcoding to 8 for target size for now for parent cen…
john-wagster Jul 11, 2025
747150e
Merge branch 'main' into ivf_hkmeans_struc2
john-wagster Jul 14, 2025
c98ea03
Merge branch 'main' into ivf_hkmeans_struc2
john-wagster Jul 14, 2025
09c7c0e
fixed duplicate centroid check
john-wagster Jul 15, 2025
a85c16f
clean up related to temp file
john-wagster Jul 16, 2025
697985a
Merge branch 'main' into ivf_hkmeans_struc2
john-wagster Jul 16, 2025
de16ea1
clean up
john-wagster Jul 16, 2025
658f8ed
merge w main
john-wagster Jul 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,52 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVect
super(state, rawVectorsReader);
}

private abstract static class BaseCentroidQueryScorer implements CentroidQueryScorer {

// TODO can we do this in off-heap blocks?
float int4QuantizedScore(
float qcDist,
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
int dims,
float[] targetCorrections,
int targetComponentSum,
float centroidDp,
VectorSimilarityFunction similarityFunction
) {
float ax = targetCorrections[0];
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
float ay = queryCorrections.lowerInterval();
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
float y1 = queryCorrections.quantizedComponentSum();
float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
if (similarityFunction == EUCLIDEAN) {
score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
return Math.max(1 / (1f + score), 0);
} else {
// For cosine and max inner product, we need to apply the additional correction, which is
// assumed to be the non-centered dot-product between the vector and the centroid
score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
return VectorUtil.scaleMaxInnerProductScore(score);
}
return Math.max((1f + score) / 2f, 0);
}
}
}

private abstract static class ChildCentroidQueryScorer extends BaseCentroidQueryScorer implements CentroidWClusterOffsetQueryScorer {}

private abstract static class ParentCentroidQueryScorer extends BaseCentroidQueryScorer implements CentroidWChildrenQueryScorer {}

@Override
CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
throws IOException {
ChildCentroidQueryScorer getChildCentroidScorer(
FieldInfo fieldInfo,
int numParentCentroids,
int numCentroids,
IndexInput centroids,
float[] targetQuery
) throws IOException {
final FieldEntry fieldEntry = fields.get(fieldInfo.number);
final float globalCentroidDp = fieldEntry.globalCentroidDp();
final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
Expand All @@ -65,11 +108,16 @@ CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, Ind
quantized[i] = (byte) scratch[i];
}
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
return new CentroidQueryScorer() {
return new ChildCentroidQueryScorer() {
int currentCentroid = -1;
private final float[] centroid = new float[fieldInfo.getVectorDimension()];
private final float[] centroidCorrectiveValues = new float[3];
private final long rawCentroidsOffset = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES);
private int clusterOrdinal;
private final long quantizedVectorByteSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES;
private final long quantizedVectorNodeByteSize = quantizedVectorByteSize + Integer.BYTES;
private final long parentNodeByteSize = quantizedVectorByteSize + 2 * Integer.BYTES;
private final long quantizedCentroidsOffset = numParentCentroids * parentNodeByteSize;
private final long rawCentroidsOffset = numParentCentroids * parentNodeByteSize + numCentroids * quantizedVectorNodeByteSize;
private final long rawCentroidsByteSize = (long) Float.BYTES * fieldInfo.getVectorDimension();

@Override
Expand All @@ -87,18 +135,36 @@ public float[] centroid(int centroidOrdinal) throws IOException {
return centroid;
}

public void bulkScore(NeighborQueue queue) throws IOException {
@Override
public void bulkScore(NeighborQueue queue, int start, int end) throws IOException {
// TODO: bulk score centroids like we do with posting lists
centroids.seek(0L);
for (int i = 0; i < numCentroids; i++) {
assert start >= 0;
assert end > 0;
assert start + end <= numCentroids;
centroids.seek(quantizedCentroidsOffset + quantizedVectorNodeByteSize * start);
for (int i = start; i < end; i++) {
queue.add(i, score());
}
}

// TODO: this causes seeks refactor to move this to the end of the block in this file
@Override
public int getClusterOrdinal(int centroidOrdinal) throws IOException {
if (centroidOrdinal != currentCentroid) {
centroids.seek(quantizedCentroidsOffset + quantizedVectorNodeByteSize * centroidOrdinal + quantizedVectorByteSize);
clusterOrdinal = centroids.readInt();
}
return clusterOrdinal;
}

private float score() throws IOException {
final float qcDist = scorer.int4DotProduct(quantized);
centroids.readFloats(centroidCorrectiveValues, 0, 3);
final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort());

// TODO: should we consider a different format such as moving these to the beginning of the file to benefit bulk read
centroids.skipBytes(Integer.BYTES);

return int4QuantizedScore(
qcDist,
queryParams,
Expand All @@ -109,46 +175,125 @@ private float score() throws IOException {
fieldInfo.getVectorSimilarityFunction()
);
}
};
}

// TODO can we do this in off-heap blocks?
private float int4QuantizedScore(
float qcDist,
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
int dims,
float[] targetCorrections,
int targetComponentSum,
float centroidDp,
VectorSimilarityFunction similarityFunction
) {
float ax = targetCorrections[0];
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
float ay = queryCorrections.lowerInterval();
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
float y1 = queryCorrections.quantizedComponentSum();
float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
if (similarityFunction == EUCLIDEAN) {
score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
return Math.max(1 / (1f + score), 0);
} else {
// For cosine and max inner product, we need to apply the additional correction, which is
// assumed to be the non-centered dot-product between the vector and the centroid
score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
return VectorUtil.scaleMaxInnerProductScore(score);
}
return Math.max((1f + score) / 2f, 0);
@Override
ParentCentroidQueryScorer getParentCentroidScorer(
FieldInfo fieldInfo,
int numParentCentroids,
IndexInput centroids,
float[] targetQuery
) throws IOException {
FieldEntry fieldEntry = fields.get(fieldInfo.number);
float globalCentroidDp = fieldEntry.globalCentroidDp();
OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
final int[] scratch = new int[targetQuery.length];
final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
ArrayUtil.copyArray(targetQuery),
scratch,
(byte) 4,
fieldEntry.globalCentroid()
);
final byte[] quantized = new byte[targetQuery.length];
for (int i = 0; i < quantized.length; i++) {
quantized[i] = (byte) scratch[i];
}
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
return new ParentCentroidQueryScorer() {
int currentCentroid = -1;
private final float[] centroidCorrectiveValues = new float[3];
private final long quantizedVectorByteSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES;
private final long parentNodeByteSize = quantizedVectorByteSize + 2 * Integer.BYTES;

private int childCentroidStart;
private int childCount;

@Override
public int size() {
return numParentCentroids;
}

@Override
public float[] centroid(int centroidOrdinal) throws IOException {
throw new IllegalStateException("can't score at the parent level");
}

private void readChildDetails(int centroidOrdinal) throws IOException {
if (centroidOrdinal == currentCentroid) {
return;
}
centroids.seek(parentNodeByteSize * centroidOrdinal + quantizedVectorByteSize);
childCentroidStart = centroids.readInt();
childCount = centroids.readInt();
currentCentroid = centroidOrdinal;
}

@Override
public int getChildCentroidStart(int centroidOrdinal) throws IOException {
readChildDetails(centroidOrdinal);
return childCentroidStart;
}

@Override
public int getChildCount(int centroidOrdinal) throws IOException {
readChildDetails(centroidOrdinal);
return childCount;
}

@Override
public void bulkScore(NeighborQueue queue, int start, int end) throws IOException {
assert start > 0;
assert end > 0;
assert start + end <= numParentCentroids;
// TODO: bulk score centroids like we do with posting lists
centroids.seek(parentNodeByteSize * start);
for (int i = start; i < end; i++) {
queue.add(i, score());
}
}

private float score() throws IOException {
final float qcDist = scorer.int4DotProduct(quantized);
centroids.readFloats(centroidCorrectiveValues, 0, 3);
final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort());

// TODO: should we consider a different format such as moving these to the beginning of the file to benefit bulk read
// TODO: cache these at this point when scoring since we'll likely read many of them?
// child partition start, child partition count
centroids.skipBytes(Integer.BYTES * 2);

return int4QuantizedScore(
qcDist,
queryParams,
fieldInfo.getVectorDimension(),
centroidCorrectiveValues,
quantizedCentroidComponentSum,
globalCentroidDp,
fieldInfo.getVectorSimilarityFunction()
);
}
};
}

@Override
NeighborQueue scorePostingLists(
FieldInfo fieldInfo,
KnnCollector knnCollector,
CentroidQueryScorer centroidQueryScorer,
int nProbe,
int start,
int count
) throws IOException {
NeighborQueue neighborQueue = new NeighborQueue(count, true);
centroidQueryScorer.bulkScore(neighborQueue, start, start + count);
return neighborQueue;
}

@Override
NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer centroidQueryScorer, int nProbe)
throws IOException {
NeighborQueue neighborQueue = new NeighborQueue(centroidQueryScorer.size(), true);
centroidQueryScorer.bulkScore(neighborQueue);
return neighborQueue;
return scorePostingLists(fieldInfo, knnCollector, centroidQueryScorer, nProbe, 0, centroidQueryScorer.size());
}

@Override
Expand Down
Loading