1717
1818import java .lang .foreign .MemorySegment ;
1919
20+ import static java .lang .foreign .ValueLayout .JAVA_FLOAT_UNALIGNED ;
2021import static org .hamcrest .Matchers .containsString ;
2122
2223public class JDKVectorLibraryInt7uTests extends VectorSimilarityFunctionsTests {
@@ -71,6 +72,11 @@ public void testInt7BinaryVectors() {
7172 assertEquals (expected , dotProduct7u (heapSeg1 , heapSeg2 , dims ));
7273 assertEquals (expected , dotProduct7u (nativeSeg1 , heapSeg2 , dims ));
7374 assertEquals (expected , dotProduct7u (heapSeg1 , nativeSeg2 , dims ));
75+
76+ // trivial bulk with a single vector
77+ float [] bulkScore = new float [1 ];
78+ dotProduct7uBulk (nativeSeg1 , nativeSeg2 , dims , 1 , MemorySegment .ofArray (bulkScore ));
79+ assertEquals (expected , bulkScore [0 ], 0f );
7480 }
7581
7682 // square distance
@@ -86,6 +92,32 @@ public void testInt7BinaryVectors() {
8692 }
8793 }
8894
95+ public void testInt7uBulk () {
96+ assumeTrue (notSupportedMsg (), supported ());
97+ final int dims = size ;
98+ final int numVecs = randomIntBetween (2 , 101 );
99+ var values = new byte [numVecs ][dims ];
100+ var segment = arena .allocate ((long ) dims * numVecs );
101+ for (int i = 0 ; i < numVecs ; i ++) {
102+ randomBytesBetween (values [i ], MIN_INT7_VALUE , MAX_INT7_VALUE );
103+ MemorySegment .copy (MemorySegment .ofArray (values [i ]), 0L , segment , (long ) i * dims , dims );
104+ }
105+ int queryOrd = randomInt (numVecs - 1 );
106+ float [] expectedScores = new float [numVecs ];
107+ dotProductBulkScalar (values [queryOrd ], values , expectedScores );
108+
109+ var nativeQuerySeg = segment .asSlice ((long ) queryOrd * dims , dims );
110+ var bulkScoresSeg = arena .allocate ((long ) numVecs * Float .BYTES );
111+ dotProduct7uBulk (segment , nativeQuerySeg , dims , numVecs , bulkScoresSeg );
112+ assertScoresEquals (expectedScores , bulkScoresSeg );
113+
114+ if (supportsHeapSegments ()) {
115+ float [] bulkScores = new float [numVecs ];
116+ dotProduct7uBulk (segment , nativeQuerySeg , dims , numVecs , MemorySegment .ofArray (bulkScores ));
117+ assertArrayEquals (expectedScores , bulkScores , 0f );
118+ }
119+ }
120+
89121 public void testIllegalDims () {
90122 assumeTrue (notSupportedMsg (), supported ());
91123 var segment = arena .allocate ((long ) size * 3 );
@@ -109,6 +141,26 @@ public void testIllegalDims() {
109141 assertThat (e6 .getMessage (), containsString ("out of bounds for length" ));
110142 }
111143
144+ public void testBulkIllegalDims () {
145+ assumeTrue (notSupportedMsg (), supported ());
146+ var segA = arena .allocate ((long ) size * 3 );
147+ var segB = arena .allocate ((long ) size * 3 );
148+ var segS = arena .allocate ((long ) size * Float .BYTES );
149+
150+ var e1 = expectThrows (IOOBE , () -> dotProduct7uBulk (segA , segB , size , 4 , segS ));
151+ assertThat (e1 .getMessage (), containsString ("out of bounds for length" ));
152+
153+ var e2 = expectThrows (IOOBE , () -> dotProduct7uBulk (segA , segB , size , -1 , segS ));
154+ assertThat (e2 .getMessage (), containsString ("out of bounds for length" ));
155+
156+ var e3 = expectThrows (IOOBE , () -> dotProduct7uBulk (segA , segB , -1 , 3 , segS ));
157+ assertThat (e3 .getMessage (), containsString ("out of bounds for length" ));
158+
159+ var tooSmall = arena .allocate ((long ) 3 * Float .BYTES - 1 );
160+ var e4 = expectThrows (IOOBE , () -> dotProduct7uBulk (segA , segB , size , 3 , tooSmall ));
161+ assertThat (e4 .getMessage (), containsString ("out of bounds for length" ));
162+ }
163+
112164 int dotProduct7u (MemorySegment a , MemorySegment b , int length ) {
113165 try {
114166 return (int ) getVectorDistance ().dotProductHandle7u ().invokeExact (a , b , length );
@@ -137,6 +189,20 @@ int squareDistance7u(MemorySegment a, MemorySegment b, int length) {
137189 }
138190 }
139191
192+ void dotProduct7uBulk (MemorySegment a , MemorySegment b , int dims , int count , MemorySegment result ) {
193+ try {
194+ getVectorDistance ().dotProductHandle7uBulk ().invokeExact (a , b , dims , count , result );
195+ } catch (Throwable e ) {
196+ if (e instanceof Error err ) {
197+ throw err ;
198+ } else if (e instanceof RuntimeException re ) {
199+ throw re ;
200+ } else {
201+ throw new RuntimeException (e );
202+ }
203+ }
204+ }
205+
140206 /** Computes the dot product of the given vectors a and b. */
141207 static int dotProductScalar (byte [] a , byte [] b ) {
142208 int res = 0 ;
@@ -156,4 +222,18 @@ static int squareDistanceScalar(byte[] a, byte[] b) {
156222 }
157223 return squareSum ;
158224 }
225+
226+ static void dotProductBulkScalar (byte [] query , byte [][] data , float [] scores ) {
227+ for (int i = 0 ; i < data .length ; i ++) {
228+ scores [i ] = dotProductScalar (query , data [i ]);
229+ }
230+ }
231+
232+ static void assertScoresEquals (float [] expectedScores , MemorySegment expectedScoresSeg ) {
233+ assert expectedScores .length == (expectedScoresSeg .byteSize () / Float .BYTES );
234+ for (int i = 0 ; i < expectedScores .length ; i ++) {
235+ assertEquals (expectedScores [i ], expectedScoresSeg .get (JAVA_FLOAT_UNALIGNED , i * Float .BYTES ), 0f );
236+ }
237+ }
238+
159239}
0 commit comments