1313import jdk .incubator .vector .FloatVector ;
1414import jdk .incubator .vector .IntVector ;
1515import jdk .incubator .vector .LongVector ;
16+ import jdk .incubator .vector .Vector ;
1617import jdk .incubator .vector .VectorMask ;
1718import jdk .incubator .vector .VectorOperators ;
1819import jdk .incubator .vector .VectorShape ;
1920import jdk .incubator .vector .VectorSpecies ;
2021
22+ import org .apache .lucene .util .BitUtil ;
2123import org .apache .lucene .util .Constants ;
2224
2325public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
@@ -52,6 +54,13 @@ public long ipByteBinByte(byte[] q, byte[] d) {
5254
5355 @ Override
5456 public int ipByteBit (byte [] q , byte [] d ) {
57+ if (d .length >= 16 && HAS_FAST_INTEGER_VECTORS ) {
58+ if (VECTOR_BITSIZE >= 512 ) {
59+ return ipByteBit512 (q , d );
60+ } else if (VECTOR_BITSIZE == 256 ) {
61+ return ipByteBit256 (q , d );
62+ }
63+ }
5564 return DefaultESVectorUtilSupport .ipByteBitImpl (q , d );
5665 }
5766
@@ -175,25 +184,71 @@ public static long ipByteBin128(byte[] q, byte[] d) {
175184 return subRet0 + (subRet1 << 1 ) + (subRet2 << 2 ) + (subRet3 << 3 );
176185 }
177186
178- private static final VectorSpecies <Float > FLOAT_SPECIES_8 = FloatVector .SPECIES_256 ;
179- private static final VectorSpecies <Float > FLOAT_SPECIES_16 = FloatVector .SPECIES_512 ;
187+ private static final VectorSpecies <Integer > INT_SPECIES_512 = IntVector .SPECIES_512 ;
188+ private static final VectorSpecies <Byte > BYTE_SPECIES_FOR_INT_512 = VectorSpecies .of (
189+ byte .class ,
190+ VectorShape .forBitSize (INT_SPECIES_512 .vectorBitSize () / Integer .BYTES )
191+ );
192+ private static final VectorSpecies <Integer > INT_SPECIES_256 = IntVector .SPECIES_256 ;
193+ private static final VectorSpecies <Byte > BYTE_SPECIES_FOR_INT_256 = VectorSpecies .of (
194+ byte .class ,
195+ VectorShape .forBitSize (INT_SPECIES_256 .vectorBitSize () / Integer .BYTES )
196+ );
197+
198+ static int ipByteBit512 (byte [] q , byte [] d ) {
199+ assert q .length == d .length * Byte .SIZE ;
200+ IntVector acc = IntVector .zero (INT_SPECIES_512 );
201+
202+ int i = 0 ;
203+ for (; i < BYTE_SPECIES_FOR_INT_512 .loopBound (q .length ); i += BYTE_SPECIES_FOR_INT_512 .length ()) {
204+ Vector <Integer > bytes = ByteVector .fromArray (BYTE_SPECIES_FOR_INT_512 , q , i ).castShape (INT_SPECIES_512 , 0 );
205+ long maskBits = Integer .reverse ((short ) BitUtil .VH_BE_SHORT .get (d , i / 8 )) >> 16 ;
206+
207+ acc = acc .add (bytes , VectorMask .fromLong (INT_SPECIES_512 , maskBits ));
208+ }
209+
210+ int sum = acc .reduceLanes (VectorOperators .ADD );
211+ if (i < q .length ) {
212+ // do the tail
213+ sum += DefaultESVectorUtilSupport .ipByteBitImpl (q , d , i );
214+ }
215+ return sum ;
216+ }
217+
218+ static int ipByteBit256 (byte [] q , byte [] d ) {
219+ assert q .length == d .length * Byte .SIZE ;
220+ IntVector acc = IntVector .zero (INT_SPECIES_256 );
221+
222+ int i = 0 ;
223+ for (; i < BYTE_SPECIES_FOR_INT_256 .loopBound (q .length ); i += BYTE_SPECIES_FOR_INT_256 .length ()) {
224+ Vector <Integer > bytes = ByteVector .fromArray (BYTE_SPECIES_FOR_INT_256 , q , i ).castShape (INT_SPECIES_256 , 0 );
225+ long maskBits = Integer .reverse (d [i / 8 ]) >> 24 ;
226+
227+ acc = acc .add (bytes , VectorMask .fromLong (INT_SPECIES_256 , maskBits ));
228+ }
180229
181- private static long reverse (byte b ) {
182- // see https://graphics.stanford.edu/~seander/bithacks.html#ReverseByteWith64Bits
183- return ((((b & 0xff ) * 0x80200802L ) & 0x0884422110L ) * 0x0101010101L >> 32 ) & 0xff ;
230+ int sum = acc .reduceLanes (VectorOperators .ADD );
231+ if (i < q .length ) {
232+ // do the tail
233+ sum += DefaultESVectorUtilSupport .ipByteBitImpl (q , d , i );
234+ }
235+ return sum ;
184236 }
185237
238+ private static final VectorSpecies <Float > FLOAT_SPECIES_512 = FloatVector .SPECIES_512 ;
239+ private static final VectorSpecies <Float > FLOAT_SPECIES_256 = FloatVector .SPECIES_256 ;
240+
186241 static float ipFloatBit512 (float [] q , byte [] d ) {
187242 assert q .length == d .length * Byte .SIZE ;
188- FloatVector acc = FloatVector .zero (FLOAT_SPECIES_16 );
243+ FloatVector acc = FloatVector .zero (FLOAT_SPECIES_512 );
189244
190245 int i = 0 ;
191- for (; i < FLOAT_SPECIES_16 .loopBound (q .length ); i += FLOAT_SPECIES_16 .length ()) {
192- FloatVector floats = FloatVector .fromArray (FLOAT_SPECIES_16 , q , i );
246+ for (; i < FLOAT_SPECIES_512 .loopBound (q .length ); i += FLOAT_SPECIES_512 .length ()) {
247+ FloatVector floats = FloatVector .fromArray (FLOAT_SPECIES_512 , q , i );
193248 // use the two bytes corresponding to the same sections
194249 // of the bit vector as a mask for addition
195- long maskBits = reverse (d [ i / 8 ]) | reverse ( d [ i / 8 + 1 ]) << 8 ;
196- acc = acc .add (floats , VectorMask .fromLong (FLOAT_SPECIES_16 , maskBits ));
250+ long maskBits = Integer . reverse (( short ) BitUtil . VH_BE_SHORT . get ( d , i / 8 )) >> 16 ;
251+ acc = acc .add (floats , VectorMask .fromLong (FLOAT_SPECIES_512 , maskBits ));
197252 }
198253
199254 float sum = acc .reduceLanes (VectorOperators .ADD );
@@ -207,15 +262,15 @@ static float ipFloatBit512(float[] q, byte[] d) {
207262
208263 static float ipFloatBit256 (float [] q , byte [] d ) {
209264 assert q .length == d .length * Byte .SIZE ;
210- FloatVector acc = FloatVector .zero (FLOAT_SPECIES_8 );
265+ FloatVector acc = FloatVector .zero (FLOAT_SPECIES_256 );
211266
212267 int i = 0 ;
213- for (; i < FLOAT_SPECIES_8 .loopBound (q .length ); i += FLOAT_SPECIES_8 .length ()) {
214- FloatVector floats = FloatVector .fromArray (FLOAT_SPECIES_8 , q , i );
268+ for (; i < FLOAT_SPECIES_256 .loopBound (q .length ); i += FLOAT_SPECIES_256 .length ()) {
269+ FloatVector floats = FloatVector .fromArray (FLOAT_SPECIES_256 , q , i );
215270 // use the byte corresponding to the same section
216271 // of the bit vector as a mask for addition
217- long maskBits = reverse (d [i / 8 ]);
218- acc = acc .add (floats , VectorMask .fromLong (FLOAT_SPECIES_8 , maskBits ));
272+ long maskBits = Integer . reverse (d [i / 8 ]) >> 24 ;
273+ acc = acc .add (floats , VectorMask .fromLong (FLOAT_SPECIES_256 , maskBits ));
219274 }
220275
221276 float sum = acc .reduceLanes (VectorOperators .ADD );
0 commit comments