1313import org .elasticsearch .simdvec .internal .vectorization .ESVectorizationProvider ;
1414
1515import java .util .Arrays ;
16+ import java .util .function .ToDoubleBiFunction ;
17+ import java .util .function .ToLongBiFunction ;
1618
1719import static org .elasticsearch .simdvec .internal .vectorization .ESVectorUtilSupport .B_QUERY ;
1820
@@ -40,8 +42,16 @@ public void testIpFloatBit() {
4042 }
4143
4244 public void testIpFloatByte () {
43- float [] q = new float [16 ];
44- byte [] d = new byte [16 ];
45+ testIpFloatByteImpl (ESVectorUtil ::ipFloatByte );
46+ testIpFloatByteImpl (defaultedProvider .getVectorUtilSupport ()::ipFloatByte );
47+ testIpFloatByteImpl (defOrPanamaProvider .getVectorUtilSupport ()::ipFloatByte );
48+ }
49+
50+ private void testIpFloatByteImpl (ToDoubleBiFunction <float [], byte []> impl ) {
51+ int vectorSize = randomIntBetween (1 , 1024 );
52+
53+ float [] q = new float [vectorSize ];
54+ byte [] d = new byte [vectorSize ];
4555 for (int i = 0 ; i < q .length ; i ++) {
4656 q [i ] = random ().nextFloat ();
4757 }
@@ -51,7 +61,7 @@ public void testIpFloatByte() {
5161 for (int i = 0 ; i < q .length ; i ++) {
5262 expected += q [i ] * d [i ];
5363 }
54- assertEquals (expected , ESVectorUtil . ipFloatByte (q , d ), 1e-6 );
64+ assertEquals (expected , impl . applyAsDouble (q , d ), 1e-2 );
5565 }
5666
5767 public void testBitAndCount () {
@@ -74,65 +84,57 @@ public void testBasicIpByteBin() {
7484 testBasicIpByteBinImpl (defOrPanamaProvider .getVectorUtilSupport ()::ipByteBinByte );
7585 }
7686
77- interface IpByteBin {
78- long apply (byte [] q , byte [] d );
79- }
80-
81- interface BitOps {
82- long apply (byte [] q , byte [] d );
83- }
84-
85- void testBasicBitAndImpl (BitOps bitAnd ) {
86- assertEquals (0 , bitAnd .apply (new byte [] { 0 }, new byte [] { 0 }));
87- assertEquals (0 , bitAnd .apply (new byte [] { 1 }, new byte [] { 0 }));
88- assertEquals (0 , bitAnd .apply (new byte [] { 0 }, new byte [] { 1 }));
89- assertEquals (1 , bitAnd .apply (new byte [] { 1 }, new byte [] { 1 }));
87+ void testBasicBitAndImpl (ToLongBiFunction <byte [], byte []> bitAnd ) {
88+ assertEquals (0 , bitAnd .applyAsLong (new byte [] { 0 }, new byte [] { 0 }));
89+ assertEquals (0 , bitAnd .applyAsLong (new byte [] { 1 }, new byte [] { 0 }));
90+ assertEquals (0 , bitAnd .applyAsLong (new byte [] { 0 }, new byte [] { 1 }));
91+ assertEquals (1 , bitAnd .applyAsLong (new byte [] { 1 }, new byte [] { 1 }));
9092 byte [] a = new byte [31 ];
9193 byte [] b = new byte [31 ];
9294 random ().nextBytes (a );
9395 random ().nextBytes (b );
9496 int expected = scalarBitAnd (a , b );
95- assertEquals (expected , bitAnd .apply (a , b ));
97+ assertEquals (expected , bitAnd .applyAsLong (a , b ));
9698 }
9799
98- void testBasicIpByteBinImpl (IpByteBin ipByteBinFunc ) {
99- assertEquals (15L , ipByteBinFunc .apply (new byte [] { 1 , 1 , 1 , 1 }, new byte [] { 1 }));
100- assertEquals (30L , ipByteBinFunc .apply (new byte [] { 1 , 2 , 1 , 2 , 1 , 2 , 1 , 2 }, new byte [] { 1 , 2 }));
100+ void testBasicIpByteBinImpl (ToLongBiFunction < byte [], byte []> ipByteBinFunc ) {
101+ assertEquals (15L , ipByteBinFunc .applyAsLong (new byte [] { 1 , 1 , 1 , 1 }, new byte [] { 1 }));
102+ assertEquals (30L , ipByteBinFunc .applyAsLong (new byte [] { 1 , 2 , 1 , 2 , 1 , 2 , 1 , 2 }, new byte [] { 1 , 2 }));
101103
102104 var d = new byte [] { 1 , 2 , 3 };
103105 var q = new byte [] { 1 , 2 , 3 , 1 , 2 , 3 , 1 , 2 , 3 , 1 , 2 , 3 };
104106 assert scalarIpByteBin (q , d ) == 60L ; // 4 + 8 + 16 + 32
105- assertEquals (60L , ipByteBinFunc .apply (q , d ));
107+ assertEquals (60L , ipByteBinFunc .applyAsLong (q , d ));
106108
107109 d = new byte [] { 1 , 2 , 3 , 4 };
108110 q = new byte [] { 1 , 2 , 3 , 4 , 1 , 2 , 3 , 4 , 1 , 2 , 3 , 4 , 1 , 2 , 3 , 4 };
109111 assert scalarIpByteBin (q , d ) == 75L ; // 5 + 10 + 20 + 40
110- assertEquals (75L , ipByteBinFunc .apply (q , d ));
112+ assertEquals (75L , ipByteBinFunc .applyAsLong (q , d ));
111113
112114 d = new byte [] { 1 , 2 , 3 , 4 , 5 };
113115 q = new byte [] { 1 , 2 , 3 , 4 , 5 , 1 , 2 , 3 , 4 , 5 , 1 , 2 , 3 , 4 , 5 , 1 , 2 , 3 , 4 , 5 };
114116 assert scalarIpByteBin (q , d ) == 105L ; // 7 + 14 + 28 + 56
115- assertEquals (105L , ipByteBinFunc .apply (q , d ));
117+ assertEquals (105L , ipByteBinFunc .applyAsLong (q , d ));
116118
117119 d = new byte [] { 1 , 2 , 3 , 4 , 5 , 6 };
118120 q = new byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 1 , 2 , 3 , 4 , 5 , 6 , 1 , 2 , 3 , 4 , 5 , 6 , 1 , 2 , 3 , 4 , 5 , 6 };
119121 assert scalarIpByteBin (q , d ) == 135L ; // 9 + 18 + 36 + 72
120- assertEquals (135L , ipByteBinFunc .apply (q , d ));
122+ assertEquals (135L , ipByteBinFunc .applyAsLong (q , d ));
121123
122124 d = new byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 7 };
123125 q = new byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 7 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 1 , 2 , 3 , 4 , 5 , 6 , 7 };
124126 assert scalarIpByteBin (q , d ) == 180L ; // 12 + 24 + 48 + 96
125- assertEquals (180L , ipByteBinFunc .apply (q , d ));
127+ assertEquals (180L , ipByteBinFunc .applyAsLong (q , d ));
126128
127129 d = new byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 };
128130 q = new byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 };
129131 assert scalarIpByteBin (q , d ) == 195L ; // 13 + 26 + 52 + 104
130- assertEquals (195L , ipByteBinFunc .apply (q , d ));
132+ assertEquals (195L , ipByteBinFunc .applyAsLong (q , d ));
131133
132134 d = new byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 };
133135 q = new byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 };
134136 assert scalarIpByteBin (q , d ) == 225L ; // 15 + 30 + 60 + 120
135- assertEquals (225L , ipByteBinFunc .apply (q , d ));
137+ assertEquals (225L , ipByteBinFunc .applyAsLong (q , d ));
136138 }
137139
138140 public void testIpByteBin () {
@@ -141,23 +143,23 @@ public void testIpByteBin() {
141143 testIpByteBinImpl (defOrPanamaProvider .getVectorUtilSupport ()::ipByteBinByte );
142144 }
143145
144- void testIpByteBinImpl (IpByteBin ipByteBinFunc ) {
146+ void testIpByteBinImpl (ToLongBiFunction < byte [], byte []> ipByteBinFunc ) {
145147 int iterations = atLeast (50 );
146148 for (int i = 0 ; i < iterations ; i ++) {
147149 int size = random ().nextInt (5000 );
148150 var d = new byte [size ];
149151 var q = new byte [size * B_QUERY ];
150152 random ().nextBytes (d );
151153 random ().nextBytes (q );
152- assertEquals (scalarIpByteBin (q , d ), ipByteBinFunc .apply (q , d ));
154+ assertEquals (scalarIpByteBin (q , d ), ipByteBinFunc .applyAsLong (q , d ));
153155
154156 Arrays .fill (d , Byte .MAX_VALUE );
155157 Arrays .fill (q , Byte .MAX_VALUE );
156- assertEquals (scalarIpByteBin (q , d ), ipByteBinFunc .apply (q , d ));
158+ assertEquals (scalarIpByteBin (q , d ), ipByteBinFunc .applyAsLong (q , d ));
157159
158160 Arrays .fill (d , Byte .MIN_VALUE );
159161 Arrays .fill (q , Byte .MIN_VALUE );
160- assertEquals (scalarIpByteBin (q , d ), ipByteBinFunc .apply (q , d ));
162+ assertEquals (scalarIpByteBin (q , d ), ipByteBinFunc .applyAsLong (q , d ));
161163 }
162164 }
163165
0 commit comments