Skip to content

Commit ddac590

Browse files
committed
Add some tests for the panama implementation
1 parent ff2e0e2 commit ddac590

File tree

1 file changed

+33
-31
lines changed

1 file changed

+33
-31
lines changed

libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
1414

1515
import java.util.Arrays;
16+
import java.util.function.ToDoubleBiFunction;
17+
import java.util.function.ToLongBiFunction;
1618

1719
import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY;
1820

@@ -40,8 +42,16 @@ public void testIpFloatBit() {
4042
}
4143

4244
public void testIpFloatByte() {
43-
float[] q = new float[16];
44-
byte[] d = new byte[16];
45+
testIpFloatByteImpl(ESVectorUtil::ipFloatByte);
46+
testIpFloatByteImpl(defaultedProvider.getVectorUtilSupport()::ipFloatByte);
47+
testIpFloatByteImpl(defOrPanamaProvider.getVectorUtilSupport()::ipFloatByte);
48+
}
49+
50+
private void testIpFloatByteImpl(ToDoubleBiFunction<float[], byte[]> impl) {
51+
int vectorSize = randomIntBetween(1, 1024);
52+
53+
float[] q = new float[vectorSize];
54+
byte[] d = new byte[vectorSize];
4555
for (int i = 0; i < q.length; i++) {
4656
q[i] = random().nextFloat();
4757
}
@@ -51,7 +61,7 @@ public void testIpFloatByte() {
5161
for (int i = 0; i < q.length; i++) {
5262
expected += q[i] * d[i];
5363
}
54-
assertEquals(expected, ESVectorUtil.ipFloatByte(q, d), 1e-6);
64+
assertEquals(expected, impl.applyAsDouble(q, d), 1e-2);
5565
}
5666

5767
public void testBitAndCount() {
@@ -74,65 +84,57 @@ public void testBasicIpByteBin() {
7484
testBasicIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte);
7585
}
7686

77-
interface IpByteBin {
78-
long apply(byte[] q, byte[] d);
79-
}
80-
81-
interface BitOps {
82-
long apply(byte[] q, byte[] d);
83-
}
84-
85-
void testBasicBitAndImpl(BitOps bitAnd) {
86-
assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 0 }));
87-
assertEquals(0, bitAnd.apply(new byte[] { 1 }, new byte[] { 0 }));
88-
assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 1 }));
89-
assertEquals(1, bitAnd.apply(new byte[] { 1 }, new byte[] { 1 }));
87+
void testBasicBitAndImpl(ToLongBiFunction<byte[], byte[]> bitAnd) {
88+
assertEquals(0, bitAnd.applyAsLong(new byte[] { 0 }, new byte[] { 0 }));
89+
assertEquals(0, bitAnd.applyAsLong(new byte[] { 1 }, new byte[] { 0 }));
90+
assertEquals(0, bitAnd.applyAsLong(new byte[] { 0 }, new byte[] { 1 }));
91+
assertEquals(1, bitAnd.applyAsLong(new byte[] { 1 }, new byte[] { 1 }));
9092
byte[] a = new byte[31];
9193
byte[] b = new byte[31];
9294
random().nextBytes(a);
9395
random().nextBytes(b);
9496
int expected = scalarBitAnd(a, b);
95-
assertEquals(expected, bitAnd.apply(a, b));
97+
assertEquals(expected, bitAnd.applyAsLong(a, b));
9698
}
9799

98-
void testBasicIpByteBinImpl(IpByteBin ipByteBinFunc) {
99-
assertEquals(15L, ipByteBinFunc.apply(new byte[] { 1, 1, 1, 1 }, new byte[] { 1 }));
100-
assertEquals(30L, ipByteBinFunc.apply(new byte[] { 1, 2, 1, 2, 1, 2, 1, 2 }, new byte[] { 1, 2 }));
100+
void testBasicIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
101+
assertEquals(15L, ipByteBinFunc.applyAsLong(new byte[] { 1, 1, 1, 1 }, new byte[] { 1 }));
102+
assertEquals(30L, ipByteBinFunc.applyAsLong(new byte[] { 1, 2, 1, 2, 1, 2, 1, 2 }, new byte[] { 1, 2 }));
101103

102104
var d = new byte[] { 1, 2, 3 };
103105
var q = new byte[] { 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3 };
104106
assert scalarIpByteBin(q, d) == 60L; // 4 + 8 + 16 + 32
105-
assertEquals(60L, ipByteBinFunc.apply(q, d));
107+
assertEquals(60L, ipByteBinFunc.applyAsLong(q, d));
106108

107109
d = new byte[] { 1, 2, 3, 4 };
108110
q = new byte[] { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 };
109111
assert scalarIpByteBin(q, d) == 75L; // 5 + 10 + 20 + 40
110-
assertEquals(75L, ipByteBinFunc.apply(q, d));
112+
assertEquals(75L, ipByteBinFunc.applyAsLong(q, d));
111113

112114
d = new byte[] { 1, 2, 3, 4, 5 };
113115
q = new byte[] { 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5 };
114116
assert scalarIpByteBin(q, d) == 105L; // 7 + 14 + 28 + 56
115-
assertEquals(105L, ipByteBinFunc.apply(q, d));
117+
assertEquals(105L, ipByteBinFunc.applyAsLong(q, d));
116118

117119
d = new byte[] { 1, 2, 3, 4, 5, 6 };
118120
q = new byte[] { 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6 };
119121
assert scalarIpByteBin(q, d) == 135L; // 9 + 18 + 36 + 72
120-
assertEquals(135L, ipByteBinFunc.apply(q, d));
122+
assertEquals(135L, ipByteBinFunc.applyAsLong(q, d));
121123

122124
d = new byte[] { 1, 2, 3, 4, 5, 6, 7 };
123125
q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7 };
124126
assert scalarIpByteBin(q, d) == 180L; // 12 + 24 + 48 + 96
125-
assertEquals(180L, ipByteBinFunc.apply(q, d));
127+
assertEquals(180L, ipByteBinFunc.applyAsLong(q, d));
126128

127129
d = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 };
128130
q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8 };
129131
assert scalarIpByteBin(q, d) == 195L; // 13 + 26 + 52 + 104
130-
assertEquals(195L, ipByteBinFunc.apply(q, d));
132+
assertEquals(195L, ipByteBinFunc.applyAsLong(q, d));
131133

132134
d = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 };
133135
q = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
134136
assert scalarIpByteBin(q, d) == 225L; // 15 + 30 + 60 + 120
135-
assertEquals(225L, ipByteBinFunc.apply(q, d));
137+
assertEquals(225L, ipByteBinFunc.applyAsLong(q, d));
136138
}
137139

138140
public void testIpByteBin() {
@@ -141,23 +143,23 @@ public void testIpByteBin() {
141143
testIpByteBinImpl(defOrPanamaProvider.getVectorUtilSupport()::ipByteBinByte);
142144
}
143145

144-
void testIpByteBinImpl(IpByteBin ipByteBinFunc) {
146+
void testIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
145147
int iterations = atLeast(50);
146148
for (int i = 0; i < iterations; i++) {
147149
int size = random().nextInt(5000);
148150
var d = new byte[size];
149151
var q = new byte[size * B_QUERY];
150152
random().nextBytes(d);
151153
random().nextBytes(q);
152-
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d));
154+
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.applyAsLong(q, d));
153155

154156
Arrays.fill(d, Byte.MAX_VALUE);
155157
Arrays.fill(q, Byte.MAX_VALUE);
156-
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d));
158+
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.applyAsLong(q, d));
157159

158160
Arrays.fill(d, Byte.MIN_VALUE);
159161
Arrays.fill(q, Byte.MIN_VALUE);
160-
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.apply(q, d));
162+
assertEquals(scalarIpByteBin(q, d), ipByteBinFunc.applyAsLong(q, d));
161163
}
162164
}
163165

0 commit comments

Comments
 (0)