1313import org .elasticsearch .simdvec .internal .vectorization .ESVectorizationProvider ;
1414
1515import java .util .Arrays ;
16+ import java .util .function .ToDoubleBiFunction ;
17+ import java .util .function .ToLongBiFunction ;
1618
1719import static org .elasticsearch .simdvec .internal .vectorization .ESVectorUtilSupport .B_QUERY ;
20+ import static org .hamcrest .Matchers .closeTo ;
1821
1922public class ESVectorUtilTests extends BaseVectorizationTests {
2023
@@ -40,8 +43,18 @@ public void testIpFloatBit() {
4043 }
4144
4245 public void testIpFloatByte () {
43- float [] q = new float [16 ];
44- byte [] d = new byte [16 ];
46+ testIpFloatByteImpl (ESVectorUtil ::ipFloatByte );
47+ testIpFloatByteImpl (defaultedProvider .getVectorUtilSupport ()::ipFloatByte );
48+ testIpFloatByteImpl (defOrPanamaProvider .getVectorUtilSupport ()::ipFloatByte );
49+ }
50+
51+ private void testIpFloatByteImpl (ToDoubleBiFunction <float [], byte []> impl ) {
52+ int vectorSize = randomIntBetween (1 , 1024 );
53+ // scale the delta according to the vector size
54+ double delta = 1e-5 * vectorSize ;
55+
56+ float [] q = new float [vectorSize ];
57+ byte [] d = new byte [vectorSize ];
4558 for (int i = 0 ; i < q .length ; i ++) {
4659 q [i ] = random ().nextFloat ();
4760 }
@@ -51,7 +64,7 @@ public void testIpFloatByte() {
5164 for (int i = 0 ; i < q .length ; i ++) {
5265 expected += q [i ] * d [i ];
5366 }
54- assertEquals ( expected , ESVectorUtil . ipFloatByte (q , d ), 1e-6 );
67+ assertThat ( impl . applyAsDouble (q , d ), closeTo ( expected , delta ) );
5568 }
5669
5770 public void testBitAndCount () {
@@ -74,65 +87,57 @@ public void testBasicIpByteBin() {
7487 testBasicIpByteBinImpl (defOrPanamaProvider .getVectorUtilSupport ()::ipByteBinByte );
7588 }
7689
77- interface IpByteBin {
78- long apply (byte [] q , byte [] d );
79- }
80-
81- interface BitOps {
82- long apply (byte [] q , byte [] d );
83- }
84-
85- void testBasicBitAndImpl (BitOps bitAnd ) {
86- assertEquals (0 , bitAnd .apply (new byte [] { 0 }, new byte [] { 0 }));
87- assertEquals (0 , bitAnd .apply (new byte [] { 1 }, new byte [] { 0 }));
88- assertEquals (0 , bitAnd .apply (new byte [] { 0 }, new byte [] { 1 }));
89- assertEquals (1 , bitAnd .apply (new byte [] { 1 }, new byte [] { 1 }));
90+ void testBasicBitAndImpl (ToLongBiFunction <byte [], byte []> bitAnd ) {
91+ assertEquals (0 , bitAnd .applyAsLong (new byte [] { 0 }, new byte [] { 0 }));
92+ assertEquals (0 , bitAnd .applyAsLong (new byte [] { 1 }, new byte [] { 0 }));
93+ assertEquals (0 , bitAnd .applyAsLong (new byte [] { 0 }, new byte [] { 1 }));
94+ assertEquals (1 , bitAnd .applyAsLong (new byte [] { 1 }, new byte [] { 1 }));
9095 byte [] a = new byte [31 ];
9196 byte [] b = new byte [31 ];
9297 random ().nextBytes (a );
9398 random ().nextBytes (b );
9499 int expected = scalarBitAnd (a , b );
95- assertEquals (expected , bitAnd .apply (a , b ));
100+ assertEquals (expected , bitAnd .applyAsLong (a , b ));
96101 }
97102
98- void testBasicIpByteBinImpl (IpByteBin ipByteBinFunc ) {
99- assertEquals (15L , ipByteBinFunc .apply (new byte [] { 1 , 1 , 1 , 1 }, new byte [] { 1 }));
100- assertEquals (30L , ipByteBinFunc .apply (new byte [] { 1 , 2 , 1 , 2 , 1 , 2 , 1 , 2 }, new byte [] { 1 , 2 }));
103+ void testBasicIpByteBinImpl (ToLongBiFunction < byte [], byte []> ipByteBinFunc ) {
104+ assertEquals (15L , ipByteBinFunc .applyAsLong (new byte [] { 1 , 1 , 1 , 1 }, new byte [] { 1 }));
105+ assertEquals (30L , ipByteBinFunc .applyAsLong (new byte [] { 1 , 2 , 1 , 2 , 1 , 2 , 1 , 2 }, new byte [] { 1 , 2 }));
101106
102107 var d = new byte [] { 1 , 2 , 3 };
103108 var q = new byte [] { 1 , 2 , 3 , 1 , 2 , 3 , 1 , 2 , 3 , 1 , 2 , 3 };
104109 assert scalarIpByteBin (q , d ) == 60L ; // 4 + 8 + 16 + 32
105- assertEquals (60L , ipByteBinFunc .apply (q , d ));
110+ assertEquals (60L , ipByteBinFunc .applyAsLong (q , d ));
106111
107112 d = new byte [] { 1 , 2 , 3 , 4 };
108113 q = new byte [] { 1 , 2 , 3 , 4 , 1 , 2 , 3 , 4 , 1 , 2 , 3 , 4 , 1 , 2 , 3 , 4 };
109114 assert scalarIpByteBin (q , d ) == 75L ; // 5 + 10 + 20 + 40
110- assertEquals (75L , ipByteBinFunc .apply (q , d ));
115+ assertEquals (75L , ipByteBinFunc .applyAsLong (q , d ));
111116
112117 d = new byte [] { 1 , 2 , 3 , 4 , 5 };
113118 q = new byte [] { 1 , 2 , 3 , 4 , 5 , 1 , 2 , 3 , 4 , 5 , 1 , 2 , 3 , 4 , 5 , 1 , 2 , 3 , 4 , 5 };
114119 assert scalarIpByteBin (q , d ) == 105L ; // 7 + 14 + 28 + 56
115- assertEquals (105L , ipByteBinFunc .apply (q , d ));
120+ assertEquals (105L , ipByteBinFunc .applyAsLong (q , d ));
116121
117122 d = new byte [] { 1 , 2 , 3 , 4 , 5 , 6 };
118123 q = new byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 1 , 2 , 3 , 4 , 5 , 6 , 1 , 2 , 3 , 4 , 5 , 6 , 1 , 2 , 3 , 4 , 5 , 6 };
119124 assert scalarIpByteBin (q , d ) == 135L ; // 9 + 18 + 36 + 72
120- assertEquals (135L , ipByteBinFunc .apply (q , d ));
125+ assertEquals (135L , ipByteBinFunc .applyAsLong (q , d ));
121126
122127 d = new byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 7 };
123128 q = new byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 7 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 1 , 2 , 3 , 4 , 5 , 6 , 7 };
124129 assert scalarIpByteBin (q , d ) == 180L ; // 12 + 24 + 48 + 96
125- assertEquals (180L , ipByteBinFunc .apply (q , d ));
130+ assertEquals (180L , ipByteBinFunc .applyAsLong (q , d ));
126131
127132 d = new byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 };
128133 q = new byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 };
129134 assert scalarIpByteBin (q , d ) == 195L ; // 13 + 26 + 52 + 104
130- assertEquals (195L , ipByteBinFunc .apply (q , d ));
135+ assertEquals (195L , ipByteBinFunc .applyAsLong (q , d ));
131136
132137 d = new byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 };
133138 q = new byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 };
134139 assert scalarIpByteBin (q , d ) == 225L ; // 15 + 30 + 60 + 120
135- assertEquals (225L , ipByteBinFunc .apply (q , d ));
140+ assertEquals (225L , ipByteBinFunc .applyAsLong (q , d ));
136141 }
137142
138143 public void testIpByteBin () {
@@ -141,23 +146,23 @@ public void testIpByteBin() {
141146 testIpByteBinImpl (defOrPanamaProvider .getVectorUtilSupport ()::ipByteBinByte );
142147 }
143148
144- void testIpByteBinImpl (IpByteBin ipByteBinFunc ) {
149+ void testIpByteBinImpl (ToLongBiFunction < byte [], byte []> ipByteBinFunc ) {
145150 int iterations = atLeast (50 );
146151 for (int i = 0 ; i < iterations ; i ++) {
147152 int size = random ().nextInt (5000 );
148153 var d = new byte [size ];
149154 var q = new byte [size * B_QUERY ];
150155 random ().nextBytes (d );
151156 random ().nextBytes (q );
152- assertEquals (scalarIpByteBin (q , d ), ipByteBinFunc .apply (q , d ));
157+ assertEquals (scalarIpByteBin (q , d ), ipByteBinFunc .applyAsLong (q , d ));
153158
154159 Arrays .fill (d , Byte .MAX_VALUE );
155160 Arrays .fill (q , Byte .MAX_VALUE );
156- assertEquals (scalarIpByteBin (q , d ), ipByteBinFunc .apply (q , d ));
161+ assertEquals (scalarIpByteBin (q , d ), ipByteBinFunc .applyAsLong (q , d ));
157162
158163 Arrays .fill (d , Byte .MIN_VALUE );
159164 Arrays .fill (q , Byte .MIN_VALUE );
160- assertEquals (scalarIpByteBin (q , d ), ipByteBinFunc .apply (q , d ));
165+ assertEquals (scalarIpByteBin (q , d ), ipByteBinFunc .applyAsLong (q , d ));
161166 }
162167 }
163168
0 commit comments