diff --git a/libs/native/src/test/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctionsTests.java b/libs/native/src/test/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctionsTests.java index 53fd6c7f1fa6b..3d8433bf36487 100644 --- a/libs/native/src/test/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctionsTests.java +++ b/libs/native/src/test/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctionsTests.java @@ -9,22 +9,54 @@ package org.elasticsearch.nativeaccess; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.common.logging.NodeNamePatternConverter; import org.elasticsearch.test.ESTestCase; +import java.lang.foreign.Arena; +import java.util.Arrays; import java.util.Optional; +import java.util.stream.IntStream; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent; import static org.hamcrest.Matchers.not; -public class VectorSimilarityFunctionsTests extends ESTestCase { +public abstract class VectorSimilarityFunctionsTests extends ESTestCase { - final Optional vectorSimilarityFunctions; + static { + NodeNamePatternConverter.setGlobalNodeName("foo"); + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + public static final Class IAE = IllegalArgumentException.class; + public static final Class IOOBE = IndexOutOfBoundsException.class; + + protected static Arena arena; + + protected final int size; + protected final Optional vectorSimilarityFunctions; + + protected static Iterable parametersFactory() { + var dims1 = Arrays.stream(new int[] { 1, 2, 4, 6, 8, 12, 13, 16, 25, 31, 32, 33, 64, 100, 128, 207, 256, 300, 512, 702, 768 }); + var dims2 = Arrays.stream(new int[] { 1000, 1023, 1024, 1025, 2047, 2048, 2049, 4095, 4096, 4097 }); + return () -> IntStream.concat(dims1, dims2).boxed().map(i -> new Object[] { i }).iterator(); + } - public VectorSimilarityFunctionsTests() { + protected VectorSimilarityFunctionsTests(int size) { logger.info(platformMsg()); + this.size = size; vectorSimilarityFunctions = NativeAccess.instance().getVectorSimilarityFunctions(); } + public static void setup() { + arena = Arena.ofConfined(); + } + + public static void cleanup() { + arena.close(); + } + public void testSupported() { supported(); } @@ -59,4 +91,9 @@ public static String platformMsg() { var osName = System.getProperty("os.name"); return "JDK=" + jdkVersion + ", os=" + osName + ", arch=" + arch; } + + // Support for passing on-heap arrays/segments to native + protected static boolean supportsHeapSegments() { + return Runtime.version().feature() >= 22; + } } diff --git a/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryTests.java b/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt7uTests.java similarity index 85% rename from libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryTests.java rename to libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt7uTests.java index 04f80ba72891f..effad86d74a3e 100644 --- a/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryTests.java +++ b/libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt7uTests.java @@ -15,47 +15,33 @@ import org.junit.AfterClass; import org.junit.BeforeClass; -import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; -import java.util.stream.IntStream; import static org.hamcrest.Matchers.containsString; -public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests { +public class JDKVectorLibraryInt7uTests extends VectorSimilarityFunctionsTests { // bounds of the range of values that can be seen by int7 scalar quantized vectors static final byte MIN_INT7_VALUE = 0; static final byte MAX_INT7_VALUE = 127; - static final Class IAE = IllegalArgumentException.class; - static final Class IOOBE = IndexOutOfBoundsException.class; - - static final int[] VECTOR_DIMS = { 1, 4, 6, 8, 13, 16, 25, 31, 32, 33, 64, 100, 128, 207, 256, 300, 512, 702, 1023, 1024, 1025 }; - - final int size; - - static Arena arena; - - final double delta; - - public JDKVectorLibraryTests(int size) { - this.size = size; - this.delta = 1e-5 * size; // scale the delta with the size + public JDKVectorLibraryInt7uTests(int size) { + super(size); } @BeforeClass - public static void setup() { - arena = Arena.ofConfined(); + public static void beforeClass() { + VectorSimilarityFunctionsTests.setup(); } @AfterClass - public static void cleanup() { - arena.close(); + public static void afterClass() { + VectorSimilarityFunctionsTests.cleanup(); } @ParametersFactory public static Iterable parametersFactory() { - return () -> IntStream.of(VECTOR_DIMS).boxed().map(i -> new Object[] { i }).iterator(); + return VectorSimilarityFunctionsTests.parametersFactory(); } public void testInt7BinaryVectors() { @@ -79,7 +65,7 @@ public void testInt7BinaryVectors() { // dot product int expected = dotProductScalar(values[first], values[second]); assertEquals(expected, dotProduct7u(nativeSeg1, nativeSeg2, dims)); - if (testWithHeapSegments()) { + if (supportsHeapSegments()) { var heapSeg1 = MemorySegment.ofArray(values[first]); var heapSeg2 = MemorySegment.ofArray(values[second]); assertEquals(expected, dotProduct7u(heapSeg1, heapSeg2, dims)); @@ -90,7 +76,7 @@ public void testInt7BinaryVectors() { // square distance expected = squareDistanceScalar(values[first], values[second]); assertEquals(expected, squareDistance7u(nativeSeg1, nativeSeg2, dims)); - if (testWithHeapSegments()) { + if (supportsHeapSegments()) { var heapSeg1 = MemorySegment.ofArray(values[first]); var heapSeg2 = MemorySegment.ofArray(values[second]); assertEquals(expected, squareDistance7u(heapSeg1, heapSeg2, dims)); @@ -100,10 +86,6 @@ public void testInt7BinaryVectors() { } } - static boolean testWithHeapSegments() { - return Runtime.version().feature() >= 22; - } - public void testIllegalDims() { assumeTrue(notSupportedMsg(), supported()); var segment = arena.allocate((long) size * 3);