Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.benchmark.vector;

import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.index.codec.vectors.cluster.NeighborHood;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

import java.io.IOException;
import java.util.Random;
import java.util.concurrent.TimeUnit;

@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.SECONDS)
@State(Scope.Benchmark)
// first iteration is complete garbage, so make sure we really warmup
@Warmup(iterations = 1, time = 1)
// real iterations. not useful to spend tons of time here, better to fork more
@Measurement(iterations = 3, time = 1)
// engage some noise reduction
@Fork(value = 1)
public class ComputeNeighboursBenchmark {

static {
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
}

@Param({ "1000", "2000", "3000", "5000", "10000", "20000", "50000" })
int numVectors;

@Param({ "384", "782", "1024" })
int dims;

float[][] vectors;
int clusterPerNeighbour = 128;

@Setup
public void setup() throws IOException {
Random random = new Random(123);
vectors = new float[numVectors][dims];
for (float[] vector : vectors) {
for (int i = 0; i < dims; i++) {
vector[i] = random.nextFloat();
}
}
}

@Benchmark
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public void bruteForce(Blackhole bh) {
bh.consume(NeighborHood.computeNeighborhoodsBruteForce(vectors, clusterPerNeighbour));
}

@Benchmark
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
public void graph(Blackhole bh) throws IOException {
bh.consume(NeighborHood.computeNeighborhoodsGraph(vectors, clusterPerNeighbour));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -139,40 +139,40 @@ private static int getBestCentroidFromNeighbours(
NeighborHood neighborhood,
float[] distances
) {
final int limit = neighborhood.neighbors.length - 3;
final int limit = neighborhood.neighbors().length - 3;
int bestCentroidOffset = centroidIdx;
assert centroidIdx >= 0 && centroidIdx < centroids.length;
float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
int i = 0;
for (; i < limit; i += 4) {
if (minDsq < neighborhood.maxIntraDistance) {
if (minDsq < neighborhood.maxIntraDistance()) {
// if the distance found is smaller than the maximum intra-cluster distance
// we don't consider it for further re-assignment
return bestCentroidOffset;
}
ESVectorUtil.squareDistanceBulk(
vector,
centroids[neighborhood.neighbors[i]],
centroids[neighborhood.neighbors[i + 1]],
centroids[neighborhood.neighbors[i + 2]],
centroids[neighborhood.neighbors[i + 3]],
centroids[neighborhood.neighbors()[i]],
centroids[neighborhood.neighbors()[i + 1]],
centroids[neighborhood.neighbors()[i + 2]],
centroids[neighborhood.neighbors()[i + 3]],
distances
);
for (int j = 0; j < distances.length; j++) {
float dsq = distances[j];
if (dsq < minDsq) {
minDsq = dsq;
bestCentroidOffset = neighborhood.neighbors[i + j];
bestCentroidOffset = neighborhood.neighbors()[i + j];
}
}
}
for (; i < neighborhood.neighbors.length; i++) {
if (minDsq < neighborhood.maxIntraDistance) {
for (; i < neighborhood.neighbors().length; i++) {
if (minDsq < neighborhood.maxIntraDistance()) {
// if the distance found is smaller than the maximum intra-cluster distance
// we don't consider it for further re-assignment
return bestCentroidOffset;
}
int offset = neighborhood.neighbors[i];
int offset = neighborhood.neighbors()[i];
// float score = neighborhood.scores[i];
assert offset >= 0 && offset < centroids.length : "Invalid neighbor offset: " + offset;
// compute the distance to the centroid
Expand Down Expand Up @@ -210,52 +210,6 @@ private static int getBestCentroid(float[][] centroids, float[] vector, float[]
return bestCentroidOffset;
}

private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) {
int k = centers.length;
assert k > clustersPerNeighborhood;
NeighborQueue[] neighborQueues = new NeighborQueue[k];
for (int i = 0; i < k; i++) {
neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true);
}
final float[] scores = new float[4];
final int limit = k - 3;
for (int i = 0; i < k - 1; i++) {
float[] center = centers[i];
int j = i + 1;
for (; j < limit; j += 4) {
ESVectorUtil.squareDistanceBulk(center, centers[j], centers[j + 1], centers[j + 2], centers[j + 3], scores);
for (int h = 0; h < 4; h++) {
neighborQueues[j + h].insertWithOverflow(i, scores[h]);
neighborQueues[i].insertWithOverflow(j + h, scores[h]);
}
}
for (; j < k; j++) {
float dsq = VectorUtil.squareDistance(center, centers[j]);
neighborQueues[j].insertWithOverflow(i, dsq);
neighborQueues[i].insertWithOverflow(j, dsq);
}
}

NeighborHood[] neighborhoods = new NeighborHood[k];
for (int i = 0; i < k; i++) {
NeighborQueue queue = neighborQueues[i];
if (queue.size() == 0) {
// no neighbors, skip
neighborhoods[i] = NeighborHood.EMPTY;
continue;
}
// consume the queue into the neighbors array and get the maximum intra-cluster distance
int[] neighbors = new int[queue.size()];
float maxIntraDistance = queue.topScore();
int iter = 0;
while (queue.size() > 0) {
neighbors[neighbors.length - ++iter] = queue.pop();
}
neighborhoods[i] = new NeighborHood(neighbors, maxIntraDistance);
}
return neighborhoods;
}

private void assignSpilled(
FloatVectorValues vectors,
KMeansIntermediate kmeansIntermediate,
Expand Down Expand Up @@ -299,8 +253,8 @@ private void assignSpilled(
if (neighborhoods != null) {
assert neighborhoods[currAssignment] != null;
NeighborHood neighborhood = neighborhoods[currAssignment];
centroidCount = neighborhood.neighbors.length;
centroidOrds = c -> neighborhood.neighbors[c];
centroidCount = neighborhood.neighbors().length;
centroidOrds = c -> neighborhood.neighbors()[c];
} else {
centroidCount = centroids.length - 1;
centroidOrds = c -> c < currAssignment ? c : c + 1; // skip the current centroid
Expand Down Expand Up @@ -344,10 +298,6 @@ private void assignSpilled(
}
}

record NeighborHood(int[] neighbors, float maxIntraDistance) {
static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY);
}

/**
* cluster using a lloyd k-means algorithm that is not neighbor aware
*
Expand Down Expand Up @@ -390,7 +340,7 @@ private void doCluster(FloatVectorValues vectors, KMeansIntermediate kMeansInter
NeighborHood[] neighborhoods = null;
// if there are very few centroids, don't bother with neighborhoods or neighbor aware clustering
if (neighborAware && centroids.length > clustersPerNeighborhood) {
neighborhoods = computeNeighborhoods(centroids, clustersPerNeighborhood);
neighborhoods = NeighborHood.computeNeighborhoods(centroids, clustersPerNeighborhood);
}
cluster(vectors, kMeansIntermediate, neighborhoods);
if (neighborAware && soarLambda >= 0) {
Expand Down
Loading