Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,54 @@

package org.elasticsearch.nativeaccess;

import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.common.logging.NodeNamePatternConverter;
import org.elasticsearch.test.ESTestCase;

import java.lang.foreign.Arena;
import java.util.Arrays;
import java.util.Optional;
import java.util.stream.IntStream;

import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent;
import static org.hamcrest.Matchers.not;

public class VectorSimilarityFunctionsTests extends ESTestCase {
public abstract class VectorSimilarityFunctionsTests extends ESTestCase {

final Optional<VectorSimilarityFunctions> vectorSimilarityFunctions;
static {
NodeNamePatternConverter.setGlobalNodeName("foo");
LogConfigurator.loadLog4jPlugins();
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
}

public static final Class<IllegalArgumentException> IAE = IllegalArgumentException.class;
public static final Class<IndexOutOfBoundsException> IOOBE = IndexOutOfBoundsException.class;

protected static Arena arena;

protected final int size;
protected final Optional<VectorSimilarityFunctions> vectorSimilarityFunctions;

protected static Iterable<Object[]> parametersFactory() {
var dims1 = Arrays.stream(new int[] { 1, 2, 4, 6, 8, 12, 13, 16, 25, 31, 32, 33, 64, 100, 128, 207, 256, 300, 512, 702, 768 });
var dims2 = Arrays.stream(new int[] { 1000, 1023, 1024, 1025, 2047, 2048, 2049, 4095, 4096, 4097 });
return () -> IntStream.concat(dims1, dims2).boxed().map(i -> new Object[] { i }).iterator();
}

public VectorSimilarityFunctionsTests() {
protected VectorSimilarityFunctionsTests(int size) {
logger.info(platformMsg());
this.size = size;
vectorSimilarityFunctions = NativeAccess.instance().getVectorSimilarityFunctions();
}

public static void setup() {
arena = Arena.ofConfined();
}

public static void cleanup() {
arena.close();
}

public void testSupported() {
supported();
}
Expand Down Expand Up @@ -59,4 +91,9 @@ public static String platformMsg() {
var osName = System.getProperty("os.name");
return "JDK=" + jdkVersion + ", os=" + osName + ", arch=" + arch;
}

// Support for passing on-heap arrays/segments to native
protected static boolean supportsHeapSegments() {
return Runtime.version().feature() >= 22;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,47 +15,33 @@
import org.junit.AfterClass;
import org.junit.BeforeClass;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.containsString;

public class JDKVectorLibraryTests extends VectorSimilarityFunctionsTests {
public class JDKVectorLibraryInt7uTests extends VectorSimilarityFunctionsTests {

// bounds of the range of values that can be seen by int7 scalar quantized vectors
static final byte MIN_INT7_VALUE = 0;
static final byte MAX_INT7_VALUE = 127;

static final Class<IllegalArgumentException> IAE = IllegalArgumentException.class;
static final Class<IndexOutOfBoundsException> IOOBE = IndexOutOfBoundsException.class;

static final int[] VECTOR_DIMS = { 1, 4, 6, 8, 13, 16, 25, 31, 32, 33, 64, 100, 128, 207, 256, 300, 512, 702, 1023, 1024, 1025 };

final int size;

static Arena arena;

final double delta;

public JDKVectorLibraryTests(int size) {
this.size = size;
this.delta = 1e-5 * size; // scale the delta with the size
public JDKVectorLibraryInt7uTests(int size) {
super(size);
}

@BeforeClass
public static void setup() {
arena = Arena.ofConfined();
public static void beforeClass() {
VectorSimilarityFunctionsTests.setup();
}

@AfterClass
public static void cleanup() {
arena.close();
public static void afterClass() {
VectorSimilarityFunctionsTests.cleanup();
}

@ParametersFactory
public static Iterable<Object[]> parametersFactory() {
return () -> IntStream.of(VECTOR_DIMS).boxed().map(i -> new Object[] { i }).iterator();
return VectorSimilarityFunctionsTests.parametersFactory();
}

public void testInt7BinaryVectors() {
Expand All @@ -79,7 +65,7 @@ public void testInt7BinaryVectors() {
// dot product
int expected = dotProductScalar(values[first], values[second]);
assertEquals(expected, dotProduct7u(nativeSeg1, nativeSeg2, dims));
if (testWithHeapSegments()) {
if (supportsHeapSegments()) {
var heapSeg1 = MemorySegment.ofArray(values[first]);
var heapSeg2 = MemorySegment.ofArray(values[second]);
assertEquals(expected, dotProduct7u(heapSeg1, heapSeg2, dims));
Expand All @@ -90,7 +76,7 @@ public void testInt7BinaryVectors() {
// square distance
expected = squareDistanceScalar(values[first], values[second]);
assertEquals(expected, squareDistance7u(nativeSeg1, nativeSeg2, dims));
if (testWithHeapSegments()) {
if (supportsHeapSegments()) {
var heapSeg1 = MemorySegment.ofArray(values[first]);
var heapSeg2 = MemorySegment.ofArray(values[second]);
assertEquals(expected, squareDistance7u(heapSeg1, heapSeg2, dims));
Expand All @@ -100,10 +86,6 @@ public void testInt7BinaryVectors() {
}
}

static boolean testWithHeapSegments() {
return Runtime.version().feature() >= 22;
}

public void testIllegalDims() {
assumeTrue(notSupportedMsg(), supported());
var segment = arena.allocate((long) size * 3);
Expand Down