Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,23 @@
/**
* Benchmark that compares various scalar quantized vector similarity function
* implementations;: scalar, lucene's panama-ized, and Elasticsearch's native.
* Run with ./gradlew -p benchmarks run --args 'VectorScorerBenchmark'
* Run with ./gradlew -p benchmarks run --args 'Int7uScorerBenchmark'
*/
public class VectorScorerBenchmark {
public class Int7uScorerBenchmark {

static {
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
}

@Param({ "96", "768", "1024" })
int dims;
int size = 2; // there are only two vectors to compare
public int dims;
final int size = 2; // there are only two vectors to compare

Directory dir;
IndexInput in;
VectorScorerFactory factory;

byte[] vec1;
byte[] vec2;
byte[] vec1, vec2;
float vec1Offset;
float vec2Offset;
float scoreCorrectionConstant;
Expand Down Expand Up @@ -139,39 +138,6 @@ public void setup() throws IOException {
nativeDotScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.DOT_PRODUCT, values, queryVec).get();
luceneSqrScorerQuery = luceneScorer(values, VectorSimilarityFunction.EUCLIDEAN, queryVec);
nativeSqrScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.EUCLIDEAN, values, queryVec).get();

// sanity
var f1 = dotProductLucene();
var f2 = dotProductNative();
var f3 = dotProductScalar();
if (f1 != f2) {
throw new AssertionError("lucene[" + f1 + "] != " + "native[" + f2 + "]");
}
if (f1 != f3) {
throw new AssertionError("lucene[" + f1 + "] != " + "scalar[" + f3 + "]");
}
// square distance
f1 = squareDistanceLucene();
f2 = squareDistanceNative();
f3 = squareDistanceScalar();
if (f1 != f2) {
throw new AssertionError("lucene[" + f1 + "] != " + "native[" + f2 + "]");
}
if (f1 != f3) {
throw new AssertionError("lucene[" + f1 + "] != " + "scalar[" + f3 + "]");
}

var q1 = dotProductLuceneQuery();
var q2 = dotProductNativeQuery();
if (q1 != q2) {
throw new AssertionError("query: lucene[" + q1 + "] != " + "native[" + q2 + "]");
}

var sqr1 = squareDistanceLuceneQuery();
var sqr2 = squareDistanceNativeQuery();
if (sqr1 != sqr2) {
throw new AssertionError("query: lucene[" + q1 + "] != " + "native[" + q2 + "]");
}
}

@TearDown
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.benchmark.vector;

import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;

import org.apache.lucene.util.Constants;
import org.elasticsearch.test.ESTestCase;
import org.junit.BeforeClass;
import org.openjdk.jmh.annotations.Param;

import java.util.Arrays;

public class Int7uScorerBenchmarkTests extends ESTestCase {

final double delta = 1e-3;
final int dims;

public Int7uScorerBenchmarkTests(int dims) {
this.dims = dims;
}

@BeforeClass
public static void skipWindows() {
assumeFalse("doesn't work on windows yet", Constants.WINDOWS);
}

public void testDotProduct() throws Exception {
for (int i = 0; i < 100; i++) {
var bench = new Int7uScorerBenchmark();
bench.dims = dims;
bench.setup();
try {
float expected = bench.dotProductScalar();
assertEquals(expected, bench.dotProductLucene(), delta);
assertEquals(expected, bench.dotProductNative(), delta);

expected = bench.dotProductLuceneQuery();
assertEquals(expected, bench.dotProductNativeQuery(), delta);
} finally {
bench.teardown();
}
}
}

public void testSquareDistance() throws Exception {
for (int i = 0; i < 100; i++) {
var bench = new Int7uScorerBenchmark();
bench.dims = dims;
bench.setup();
try {
float expected = bench.squareDistanceScalar();
assertEquals(expected, bench.squareDistanceLucene(), delta);
assertEquals(expected, bench.squareDistanceNative(), delta);

expected = bench.squareDistanceLuceneQuery();
assertEquals(expected, bench.squareDistanceNativeQuery(), delta);
} finally {
bench.teardown();
}
}
}

@ParametersFactory
public static Iterable<Object[]> parametersFactory() {
try {
var params = Int7uScorerBenchmark.class.getField("dims").getAnnotationsByType(Param.class)[0].value();
return () -> Arrays.stream(params).map(Integer::parseInt).map(i -> new Object[] { i }).iterator();
} catch (NoSuchFieldException e) {
throw new AssertionError(e);
}
}
}