diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int7uScorerBenchmark.java similarity index 86% rename from benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java rename to benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int7uScorerBenchmark.java index f56bb8995b34e..9fc0b7dc199ac 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/VectorScorerBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/Int7uScorerBenchmark.java @@ -55,24 +55,23 @@ /** * Benchmark that compares various scalar quantized vector similarity function * implementations;: scalar, lucene's panama-ized, and Elasticsearch's native. - * Run with ./gradlew -p benchmarks run --args 'VectorScorerBenchmark' + * Run with ./gradlew -p benchmarks run --args 'Int7uScorerBenchmark' */ -public class VectorScorerBenchmark { +public class Int7uScorerBenchmark { static { LogConfigurator.configureESLogging(); // native access requires logging to be initialized } @Param({ "96", "768", "1024" }) - int dims; - int size = 2; // there are only two vectors to compare + public int dims; + final int size = 2; // there are only two vectors to compare Directory dir; IndexInput in; VectorScorerFactory factory; - byte[] vec1; - byte[] vec2; + byte[] vec1, vec2; float vec1Offset; float vec2Offset; float scoreCorrectionConstant; @@ -139,39 +138,6 @@ public void setup() throws IOException { nativeDotScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.DOT_PRODUCT, values, queryVec).get(); luceneSqrScorerQuery = luceneScorer(values, VectorSimilarityFunction.EUCLIDEAN, queryVec); nativeSqrScorerQuery = factory.getInt7SQVectorScorer(VectorSimilarityFunction.EUCLIDEAN, values, queryVec).get(); - - // sanity - var f1 = dotProductLucene(); - var f2 = dotProductNative(); - var f3 = dotProductScalar(); - if (f1 != f2) { - throw new AssertionError("lucene[" + f1 + "] != " + "native[" + f2 + "]"); - } - if (f1 != f3) { - throw new AssertionError("lucene[" + f1 + "] != " + "scalar[" + f3 + "]"); - } - // square distance - f1 = squareDistanceLucene(); - f2 = squareDistanceNative(); - f3 = squareDistanceScalar(); - if (f1 != f2) { - throw new AssertionError("lucene[" + f1 + "] != " + "native[" + f2 + "]"); - } - if (f1 != f3) { - throw new AssertionError("lucene[" + f1 + "] != " + "scalar[" + f3 + "]"); - } - - var q1 = dotProductLuceneQuery(); - var q2 = dotProductNativeQuery(); - if (q1 != q2) { - throw new AssertionError("query: lucene[" + q1 + "] != " + "native[" + q2 + "]"); - } - - var sqr1 = squareDistanceLuceneQuery(); - var sqr2 = squareDistanceNativeQuery(); - if (sqr1 != sqr2) { - throw new AssertionError("query: lucene[" + q1 + "] != " + "native[" + q2 + "]"); - } } @TearDown diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/JDKVectorInt7uBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/JDKVectorInt7uBenchmark.java index 41c2b3192cc92..e09f37ef24086 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/JDKVectorInt7uBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/JDKVectorInt7uBenchmark.java @@ -52,7 +52,7 @@ public class JDKVectorInt7uBenchmark { Arena arena; - @Param({ "1", "128", "207", "256", "300", "512", "702", "1024" }) + @Param({ "1", "128", "207", "256", "300", "512", "702", "1024", "1536", "2048" }) public int size; @Setup(Level.Iteration) diff --git a/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/Int7uScorerBenchmarkTests.java b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/Int7uScorerBenchmarkTests.java new file mode 100644 index 0000000000000..b8cf689783d1d --- /dev/null +++ b/benchmarks/src/test/java/org/elasticsearch/benchmark/vector/Int7uScorerBenchmarkTests.java @@ -0,0 +1,80 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.benchmark.vector; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.util.Constants; +import org.elasticsearch.test.ESTestCase; +import org.junit.BeforeClass; +import org.openjdk.jmh.annotations.Param; + +import java.util.Arrays; + +public class Int7uScorerBenchmarkTests extends ESTestCase { + + final double delta = 1e-3; + final int dims; + + public Int7uScorerBenchmarkTests(int dims) { + this.dims = dims; + } + + @BeforeClass + public static void skipWindows() { + assumeFalse("doesn't work on windows yet", Constants.WINDOWS); + } + + public void testDotProduct() throws Exception { + for (int i = 0; i < 100; i++) { + var bench = new Int7uScorerBenchmark(); + bench.dims = dims; + bench.setup(); + try { + float expected = bench.dotProductScalar(); + assertEquals(expected, bench.dotProductLucene(), delta); + assertEquals(expected, bench.dotProductNative(), delta); + + expected = bench.dotProductLuceneQuery(); + assertEquals(expected, bench.dotProductNativeQuery(), delta); + } finally { + bench.teardown(); + } + } + } + + public void testSquareDistance() throws Exception { + for (int i = 0; i < 100; i++) { + var bench = new Int7uScorerBenchmark(); + bench.dims = dims; + bench.setup(); + try { + float expected = bench.squareDistanceScalar(); + assertEquals(expected, bench.squareDistanceLucene(), delta); + assertEquals(expected, bench.squareDistanceNative(), delta); + + expected = bench.squareDistanceLuceneQuery(); + assertEquals(expected, bench.squareDistanceNativeQuery(), delta); + } finally { + bench.teardown(); + } + } + } + + @ParametersFactory + public static Iterable parametersFactory() { + try { + var params = Int7uScorerBenchmark.class.getField("dims").getAnnotationsByType(Param.class)[0].value(); + return () -> Arrays.stream(params).map(Integer::parseInt).map(i -> new Object[] { i }).iterator(); + } catch (NoSuchFieldException e) { + throw new AssertionError(e); + } + } +}