@@ -0,0 +1,77 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.benchmark.vector;

import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.index.codec.vectors.cluster.NeighborHood;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

import java.io.IOException;
import java.util.Random;
import java.util.concurrent.TimeUnit;

@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.SECONDS)
@State(Scope.Benchmark)
// first iteration is complete garbage, so make sure we really warm up
@Warmup(iterations = 1, time = 1)
// real iterations. not useful to spend tons of time here, better to fork more
@Measurement(iterations = 3, time = 1)
// engage some noise reduction
@Fork(value = 1)
public class ComputeNeighboursBenchmark {

static {
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
}

@Param({ "1000", "2000", "3000", "5000", "10000", "20000", "50000" })
int numVectors;

@Param({ "384", "782", "1024" })
int dims;

float[][] vectors;
int clusterPerNeighbour = 128;

@Setup
public void setup() throws IOException {
Random random = new Random(123);
vectors = new float[numVectors][dims];
for (float[] vector : vectors) {
for (int i = 0; i < dims; i++) {
vector[i] = random.nextFloat();
}
}
}

@Benchmark
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public void bruteForce(Blackhole bh) {
bh.consume(NeighborHood.computeNeighborhoodsBruteForce(vectors, clusterPerNeighbour));
}

@Benchmark
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public void graph(Blackhole bh) throws IOException {
bh.consume(NeighborHood.computeNeighborhoodsGraph(vectors, clusterPerNeighbour));
}
}
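
A minimal sketch of how this benchmark could be launched programmatically through the plain JMH runner API while iterating on the neighborhood code. The runner class name and the single parameter combination are illustrative assumptions, and the usual benchmarks-module classpath (including the JMH annotation-processor output) is assumed:

package org.elasticsearch.benchmark.vector; // assumption: helper lives next to the benchmark class

import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

public class ComputeNeighboursBenchmarkRunner {
    public static void main(String[] args) throws RunnerException {
        // restrict the parameter matrix to a single combination to keep the run short
        Options opts = new OptionsBuilder()
            .include(ComputeNeighboursBenchmark.class.getSimpleName())
            .param("numVectors", "10000")
            .param("dims", "384")
            .build();
        new Runner(opts).run();
    }
}
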
@@ -139,40 +139,40 @@ private static int getBestCentroidFromNeighbours(
NeighborHood neighborhood,
float[] distances
) {
final int limit = neighborhood.neighbors.length - 3;
final int limit = neighborhood.neighbors().length - 3;
int bestCentroidOffset = centroidIdx;
assert centroidIdx >= 0 && centroidIdx < centroids.length;
float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
int i = 0;
for (; i < limit; i += 4) {
if (minDsq < neighborhood.maxIntraDistance) {
if (minDsq < neighborhood.maxIntraDistance()) {
// if the distance found is smaller than the maximum intra-cluster distance
// we don't consider it for further re-assignment
return bestCentroidOffset;
}
ESVectorUtil.squareDistanceBulk(
vector,
centroids[neighborhood.neighbors[i]],
centroids[neighborhood.neighbors[i + 1]],
centroids[neighborhood.neighbors[i + 2]],
centroids[neighborhood.neighbors[i + 3]],
centroids[neighborhood.neighbors()[i]],
centroids[neighborhood.neighbors()[i + 1]],
centroids[neighborhood.neighbors()[i + 2]],
centroids[neighborhood.neighbors()[i + 3]],
distances
);
for (int j = 0; j < distances.length; j++) {
float dsq = distances[j];
if (dsq < minDsq) {
minDsq = dsq;
bestCentroidOffset = neighborhood.neighbors[i + j];
bestCentroidOffset = neighborhood.neighbors()[i + j];
}
}
}
for (; i < neighborhood.neighbors.length; i++) {
if (minDsq < neighborhood.maxIntraDistance) {
for (; i < neighborhood.neighbors().length; i++) {
if (minDsq < neighborhood.maxIntraDistance()) {
// if the distance found is smaller than the maximum intra-cluster distance
// we don't consider it for further re-assignment
return bestCentroidOffset;
}
int offset = neighborhood.neighbors[i];
int offset = neighborhood.neighbors()[i];
// float score = neighborhood.scores[i];
assert offset >= 0 && offset < centroids.length : "Invalid neighbor offset: " + offset;
// compute the distance to the centroid
@@ -210,50 +210,14 @@ private static int getBestCentroid(float[][] centroids, float[] vector, float[]
return bestCentroidOffset;
}

private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) {
int k = centers.length;
assert k > clustersPerNeighborhood;
NeighborQueue[] neighborQueues = new NeighborQueue[k];
for (int i = 0; i < k; i++) {
neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true);
private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) throws IOException {
assert centers.length > clustersPerNeighborhood;
// experiments show that below 15k it is better to use brute force; above that, HNSW gives us a nice speed up
if (centers.length < 15_000) {
Contributor:
I think we can optimise the graph to work better for lower scale, but this is good as a first threshold. That's for segments greater than 1M with 64 vectors per centroid.

Member:
Reducing the number of connections could make this threshold smaller.

Contributor Author (@iverase, Sep 4, 2025):
Agree. I didn't spend too much time because it seems pretty fast for low values (a few seconds), so I wonder if there is a need to optimize those cases.

Member:
I think just picking something "good enough" is alright. It provides a nice improvement, and any optimizations we make won't be "format breaking" :)

return NeighborHood.computeNeighborhoodsBruteForce(centers, clustersPerNeighborhood);
} else {
return NeighborHood.computeNeighborhoodsGraph(centers, clustersPerNeighborhood);
}
final float[] scores = new float[4];
final int limit = k - 3;
for (int i = 0; i < k - 1; i++) {
float[] center = centers[i];
int j = i + 1;
for (; j < limit; j += 4) {
ESVectorUtil.squareDistanceBulk(center, centers[j], centers[j + 1], centers[j + 2], centers[j + 3], scores);
for (int h = 0; h < 4; h++) {
neighborQueues[j + h].insertWithOverflow(i, scores[h]);
neighborQueues[i].insertWithOverflow(j + h, scores[h]);
}
}
for (; j < k; j++) {
float dsq = VectorUtil.squareDistance(center, centers[j]);
neighborQueues[j].insertWithOverflow(i, dsq);
neighborQueues[i].insertWithOverflow(j, dsq);
}
}

NeighborHood[] neighborhoods = new NeighborHood[k];
for (int i = 0; i < k; i++) {
NeighborQueue queue = neighborQueues[i];
if (queue.size() == 0) {
// no neighbors, skip
neighborhoods[i] = NeighborHood.EMPTY;
continue;
}
// consume the queue into the neighbors array and get the maximum intra-cluster distance
int[] neighbors = new int[queue.size()];
float maxIntraDistance = queue.topScore();
int iter = 0;
while (queue.size() > 0) {
neighbors[neighbors.length - ++iter] = queue.pop();
}
neighborhoods[i] = new NeighborHood(neighbors, maxIntraDistance);
}
return neighborhoods;
}

private void assignSpilled(
@@ -299,8 +263,8 @@ private void assignSpilled(
if (neighborhoods != null) {
assert neighborhoods[currAssignment] != null;
NeighborHood neighborhood = neighborhoods[currAssignment];
centroidCount = neighborhood.neighbors.length;
centroidOrds = c -> neighborhood.neighbors[c];
centroidCount = neighborhood.neighbors().length;
centroidOrds = c -> neighborhood.neighbors()[c];
} else {
centroidCount = centroids.length - 1;
centroidOrds = c -> c < currAssignment ? c : c + 1; // skip the current centroid
@@ -344,10 +308,6 @@ private void assignSpilled(
}
}

record NeighborHood(int[] neighbors, float maxIntraDistance) {
static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY);
}

/**
* cluster using a lloyd k-means algorithm that is not neighbor aware
*
@@ -0,0 +1,148 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.codec.vectors.cluster;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.elasticsearch.simdvec.ESVectorUtil;

import java.io.IOException;

public record NeighborHood(int[] neighbors, float maxIntraDistance) {

static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY);

public static NeighborHood[] computeNeighborhoodsGraph(float[][] centers, int clustersPerNeighborhood) throws IOException {
final UpdateableRandomVectorScorer scorer = new UpdateableRandomVectorScorer() {
int scoringOrdinal;

@Override
public float score(int node) {
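// Lucene's EUCLIDEAN similarity is 1 / (1 + squareDistance), so a higher score means a closer centroid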
return VectorSimilarityFunction.EUCLIDEAN.compare(centers[scoringOrdinal], centers[node]);
}

@Override
public int maxOrd() {
return centers.length;
}

@Override
public void setScoringOrdinal(int node) {
scoringOrdinal = node;
}
};
final RandomVectorScorerSupplier supplier = new RandomVectorScorerSupplier() {
@Override
public UpdateableRandomVectorScorer scorer() {
return scorer;
}

@Override
public RandomVectorScorerSupplier copy() {
return this;
}
};
final OnHeapHnswGraph graph = HnswGraphBuilder.create(supplier, 16, 100, 42L).build(centers.length);
final NeighborHood[] neighborhoods = new NeighborHood[centers.length];
final SingleBit singleBit = new SingleBit(centers.length);
for (int i = 0; i < centers.length; i++) {
scorer.setScoringOrdinal(i);
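// mask out centroid i in the acceptOrds bitset so it is not returned as its own neighbor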
singleBit.indexSet = i;
final KnnCollector collector = HnswGraphSearcher.search(scorer, clustersPerNeighborhood, graph, singleBit, Integer.MAX_VALUE);
final ScoreDoc[] scoreDocs = collector.topDocs().scoreDocs;
if (scoreDocs.length == 0) {
// no neighbors, skip
neighborhoods[i] = NeighborHood.EMPTY;
continue;
}
final int[] neighbors = new int[scoreDocs.length];
for (int j = 0; j < neighbors.length; j++) {
neighbors[j] = scoreDocs[j].doc;
assert neighbors[j] != i;
}
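// the k-th (worst) neighbor score is a EUCLIDEAN similarity of 1 / (1 + d^2); invert it to recover the
// squared distance, which plays the same role as maxIntraDistance in the brute-force path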
final float minCompetitiveSimilarity = (1f / scoreDocs[neighbors.length - 1].score) - 1;
neighborhoods[i] = new NeighborHood(neighbors, minCompetitiveSimilarity);
}
return neighborhoods;
}

private static class SingleBit implements Bits {

private final int length;
private int indexSet;

SingleBit(int length) {
this.length = length;
}

@Override
public boolean get(int index) {
return index != indexSet;
}

@Override
public int length() {
return length;
}
}

public static NeighborHood[] computeNeighborhoodsBruteForce(float[][] centers, int clustersPerNeighborhood) {
int k = centers.length;
NeighborQueue[] neighborQueues = new NeighborQueue[k];
for (int i = 0; i < k; i++) {
neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true);
}
final float[] scores = new float[4];
final int limit = k - 3;
for (int i = 0; i < k - 1; i++) {
float[] center = centers[i];
int j = i + 1;
for (; j < limit; j += 4) {
ESVectorUtil.squareDistanceBulk(center, centers[j], centers[j + 1], centers[j + 2], centers[j + 3], scores);
for (int h = 0; h < 4; h++) {
neighborQueues[j + h].insertWithOverflow(i, scores[h]);
neighborQueues[i].insertWithOverflow(j + h, scores[h]);
}
}
for (; j < k; j++) {
float dsq = VectorUtil.squareDistance(center, centers[j]);
neighborQueues[j].insertWithOverflow(i, dsq);
neighborQueues[i].insertWithOverflow(j, dsq);
}
}

NeighborHood[] neighborhoods = new NeighborHood[k];
for (int i = 0; i < k; i++) {
NeighborQueue queue = neighborQueues[i];
if (queue.size() == 0) {
// no neighbors, skip
neighborhoods[i] = NeighborHood.EMPTY;
continue;
}
// consume the queue into the neighbors array and get the maximum intra-cluster distance
int[] neighbors = new int[queue.size()];
float maxIntraDistance = queue.topScore();
int iter = 0;
while (queue.size() > 0) {
neighbors[neighbors.length - ++iter] = queue.pop();
}
neighborhoods[i] = new NeighborHood(neighbors, maxIntraDistance);
}
return neighborhoods;
}
}
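
The graph-based path is an approximation of the brute-force path, so a quick sanity check is to compare the neighborhoods both methods produce for a small random set of centroids. A minimal sketch (the class name, sizes, and recall metric below are illustrative assumptions, not part of this change):

import java.io.IOException;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;

import org.elasticsearch.index.codec.vectors.cluster.NeighborHood;

public class NeighborHoodRecallCheck {
    public static void main(String[] args) throws IOException {
        Random random = new Random(42);
        float[][] centers = new float[1_000][64];
        for (float[] center : centers) {
            for (int i = 0; i < center.length; i++) {
                center[i] = random.nextFloat();
            }
        }
        NeighborHood[] exact = NeighborHood.computeNeighborhoodsBruteForce(centers, 128);
        NeighborHood[] approx = NeighborHood.computeNeighborhoodsGraph(centers, 128);
        long hits = 0;
        long total = 0;
        for (int i = 0; i < centers.length; i++) {
            Set<Integer> expected = new HashSet<>();
            for (int n : exact[i].neighbors()) {
                expected.add(n);
            }
            for (int n : approx[i].neighbors()) {
                if (expected.contains(n)) {
                    hits++;
                }
            }
            total += exact[i].neighbors().length;
        }
        // the HNSW search is approximate, so expect high but not perfect recall
        System.out.println("recall: " + (double) hits / total);
    }
}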