1313import  org .elasticsearch .simdvec .internal .vectorization .ESVectorizationProvider ;
1414
1515import  java .util .Arrays ;
16+ import  java .util .function .ToDoubleBiFunction ;
17+ import  java .util .function .ToLongBiFunction ;
1618
1719import  static  org .elasticsearch .simdvec .internal .vectorization .ESVectorUtilSupport .B_QUERY ;
20+ import  static  org .hamcrest .Matchers .closeTo ;
1821
1922public  class  ESVectorUtilTests  extends  BaseVectorizationTests  {
2023
@@ -40,8 +43,18 @@ public void testIpFloatBit() {
4043    }
4144
4245    public  void  testIpFloatByte () {
43-         float [] q  = new  float [16 ];
44-         byte [] d  = new  byte [16 ];
46+         testIpFloatByteImpl (ESVectorUtil ::ipFloatByte );
47+         testIpFloatByteImpl (defaultedProvider .getVectorUtilSupport ()::ipFloatByte );
48+         testIpFloatByteImpl (defOrPanamaProvider .getVectorUtilSupport ()::ipFloatByte );
49+     }
50+ 
51+     private  void  testIpFloatByteImpl (ToDoubleBiFunction <float [], byte []> impl ) {
52+         int  vectorSize  = randomIntBetween (1 , 1024 );
53+         // scale the delta according to the vector size 
54+         double  delta  = 1e-5  * vectorSize ;
55+ 
56+         float [] q  = new  float [vectorSize ];
57+         byte [] d  = new  byte [vectorSize ];
4558        for  (int  i  = 0 ; i  < q .length ; i ++) {
4659            q [i ] = random ().nextFloat ();
4760        }
@@ -51,7 +64,7 @@ public void testIpFloatByte() {
5164        for  (int  i  = 0 ; i  < q .length ; i ++) {
5265            expected  += q [i ] * d [i ];
5366        }
54-         assertEquals ( expected ,  ESVectorUtil . ipFloatByte (q , d ), 1e-6 );
67+         assertThat ( impl . applyAsDouble (q , d ), closeTo ( expected ,  delta ) );
5568    }
5669
5770    public  void  testBitAndCount () {
@@ -74,65 +87,57 @@ public void testBasicIpByteBin() {
7487        testBasicIpByteBinImpl (defOrPanamaProvider .getVectorUtilSupport ()::ipByteBinByte );
7588    }
7689
77-     interface  IpByteBin  {
78-         long  apply (byte [] q , byte [] d );
79-     }
80- 
81-     interface  BitOps  {
82-         long  apply (byte [] q , byte [] d );
83-     }
84- 
85-     void  testBasicBitAndImpl (BitOps  bitAnd ) {
86-         assertEquals (0 , bitAnd .apply (new  byte [] { 0  }, new  byte [] { 0  }));
87-         assertEquals (0 , bitAnd .apply (new  byte [] { 1  }, new  byte [] { 0  }));
88-         assertEquals (0 , bitAnd .apply (new  byte [] { 0  }, new  byte [] { 1  }));
89-         assertEquals (1 , bitAnd .apply (new  byte [] { 1  }, new  byte [] { 1  }));
90+     void  testBasicBitAndImpl (ToLongBiFunction <byte [], byte []> bitAnd ) {
91+         assertEquals (0 , bitAnd .applyAsLong (new  byte [] { 0  }, new  byte [] { 0  }));
92+         assertEquals (0 , bitAnd .applyAsLong (new  byte [] { 1  }, new  byte [] { 0  }));
93+         assertEquals (0 , bitAnd .applyAsLong (new  byte [] { 0  }, new  byte [] { 1  }));
94+         assertEquals (1 , bitAnd .applyAsLong (new  byte [] { 1  }, new  byte [] { 1  }));
9095        byte [] a  = new  byte [31 ];
9196        byte [] b  = new  byte [31 ];
9297        random ().nextBytes (a );
9398        random ().nextBytes (b );
9499        int  expected  = scalarBitAnd (a , b );
95-         assertEquals (expected , bitAnd .apply (a , b ));
100+         assertEquals (expected , bitAnd .applyAsLong (a , b ));
96101    }
97102
98-     void  testBasicIpByteBinImpl (IpByteBin  ipByteBinFunc ) {
99-         assertEquals (15L , ipByteBinFunc .apply (new  byte [] { 1 , 1 , 1 , 1  }, new  byte [] { 1  }));
100-         assertEquals (30L , ipByteBinFunc .apply (new  byte [] { 1 , 2 , 1 , 2 , 1 , 2 , 1 , 2  }, new  byte [] { 1 , 2  }));
103+     void  testBasicIpByteBinImpl (ToLongBiFunction < byte [],  byte []>  ipByteBinFunc ) {
104+         assertEquals (15L , ipByteBinFunc .applyAsLong (new  byte [] { 1 , 1 , 1 , 1  }, new  byte [] { 1  }));
105+         assertEquals (30L , ipByteBinFunc .applyAsLong (new  byte [] { 1 , 2 , 1 , 2 , 1 , 2 , 1 , 2  }, new  byte [] { 1 , 2  }));
101106
102107        var  d  = new  byte [] { 1 , 2 , 3  };
103108        var  q  = new  byte [] { 1 , 2 , 3 , 1 , 2 , 3 , 1 , 2 , 3 , 1 , 2 , 3  };
104109        assert  scalarIpByteBin (q , d ) == 60L ; // 4 + 8 + 16 + 32 
105-         assertEquals (60L , ipByteBinFunc .apply (q , d ));
110+         assertEquals (60L , ipByteBinFunc .applyAsLong (q , d ));
106111
107112        d  = new  byte [] { 1 , 2 , 3 , 4  };
108113        q  = new  byte [] { 1 , 2 , 3 , 4 , 1 , 2 , 3 , 4 , 1 , 2 , 3 , 4 , 1 , 2 , 3 , 4  };
109114        assert  scalarIpByteBin (q , d ) == 75L ; // 5 + 10 + 20 + 40 
110-         assertEquals (75L , ipByteBinFunc .apply (q , d ));
115+         assertEquals (75L , ipByteBinFunc .applyAsLong (q , d ));
111116
112117        d  = new  byte [] { 1 , 2 , 3 , 4 , 5  };
113118        q  = new  byte [] { 1 , 2 , 3 , 4 , 5 , 1 , 2 , 3 , 4 , 5 , 1 , 2 , 3 , 4 , 5 , 1 , 2 , 3 , 4 , 5  };
114119        assert  scalarIpByteBin (q , d ) == 105L ; // 7 + 14 + 28 + 56 
115-         assertEquals (105L , ipByteBinFunc .apply (q , d ));
120+         assertEquals (105L , ipByteBinFunc .applyAsLong (q , d ));
116121
117122        d  = new  byte [] { 1 , 2 , 3 , 4 , 5 , 6  };
118123        q  = new  byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 1 , 2 , 3 , 4 , 5 , 6 , 1 , 2 , 3 , 4 , 5 , 6 , 1 , 2 , 3 , 4 , 5 , 6  };
119124        assert  scalarIpByteBin (q , d ) == 135L ; // 9 + 18 + 36 + 72 
120-         assertEquals (135L , ipByteBinFunc .apply (q , d ));
125+         assertEquals (135L , ipByteBinFunc .applyAsLong (q , d ));
121126
122127        d  = new  byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 7  };
123128        q  = new  byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 7 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 1 , 2 , 3 , 4 , 5 , 6 , 7  };
124129        assert  scalarIpByteBin (q , d ) == 180L ; // 12 + 24 + 48 + 96 
125-         assertEquals (180L , ipByteBinFunc .apply (q , d ));
130+         assertEquals (180L , ipByteBinFunc .applyAsLong (q , d ));
126131
127132        d  = new  byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8  };
128133        q  = new  byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8  };
129134        assert  scalarIpByteBin (q , d ) == 195L ; // 13 + 26 + 52 + 104 
130-         assertEquals (195L , ipByteBinFunc .apply (q , d ));
135+         assertEquals (195L , ipByteBinFunc .applyAsLong (q , d ));
131136
132137        d  = new  byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9  };
133138        q  = new  byte [] { 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9  };
134139        assert  scalarIpByteBin (q , d ) == 225L ; // 15 + 30 + 60 + 120 
135-         assertEquals (225L , ipByteBinFunc .apply (q , d ));
140+         assertEquals (225L , ipByteBinFunc .applyAsLong (q , d ));
136141    }
137142
138143    public  void  testIpByteBin () {
@@ -141,23 +146,23 @@ public void testIpByteBin() {
141146        testIpByteBinImpl (defOrPanamaProvider .getVectorUtilSupport ()::ipByteBinByte );
142147    }
143148
144-     void  testIpByteBinImpl (IpByteBin  ipByteBinFunc ) {
149+     void  testIpByteBinImpl (ToLongBiFunction < byte [],  byte []>  ipByteBinFunc ) {
145150        int  iterations  = atLeast (50 );
146151        for  (int  i  = 0 ; i  < iterations ; i ++) {
147152            int  size  = random ().nextInt (5000 );
148153            var  d  = new  byte [size ];
149154            var  q  = new  byte [size  * B_QUERY ];
150155            random ().nextBytes (d );
151156            random ().nextBytes (q );
152-             assertEquals (scalarIpByteBin (q , d ), ipByteBinFunc .apply (q , d ));
157+             assertEquals (scalarIpByteBin (q , d ), ipByteBinFunc .applyAsLong (q , d ));
153158
154159            Arrays .fill (d , Byte .MAX_VALUE );
155160            Arrays .fill (q , Byte .MAX_VALUE );
156-             assertEquals (scalarIpByteBin (q , d ), ipByteBinFunc .apply (q , d ));
161+             assertEquals (scalarIpByteBin (q , d ), ipByteBinFunc .applyAsLong (q , d ));
157162
158163            Arrays .fill (d , Byte .MIN_VALUE );
159164            Arrays .fill (q , Byte .MIN_VALUE );
160-             assertEquals (scalarIpByteBin (q , d ), ipByteBinFunc .apply (q , d ));
165+             assertEquals (scalarIpByteBin (q , d ), ipByteBinFunc .applyAsLong (q , d ));
161166        }
162167    }
163168
0 commit comments