@@ -309,46 +309,100 @@ private float squareDistanceBody(float[] a, float[] b, int limit) {
   // We also support 128 bit vectors, going 32 bits at a time.
   // This is slower but still faster than not vectorizing at all.
 
+  private interface ByteVectorLoader {
+    int length();
+
+    ByteVector load(VectorSpecies<Byte> species, int index);
+
+    byte tail(int index);
+  }
+
+  private record ArrayLoader(byte[] arr) implements ByteVectorLoader {
+    @Override
+    public int length() {
+      return arr.length;
+    }
+
+    @Override
+    public ByteVector load(VectorSpecies<Byte> species, int index) {
+      assert index + species.length() <= length();
+      return ByteVector.fromArray(species, arr, index);
+    }
+
+    @Override
+    public byte tail(int index) {
+      assert index <= length();
+      return arr[index];
+    }
+  }
+
+  private record MemorySegmentLoader(MemorySegment segment) implements ByteVectorLoader {
+    @Override
+    public int length() {
+      return Math.toIntExact(segment.byteSize());
+    }
+
+    @Override
+    public ByteVector load(VectorSpecies<Byte> species, int index) {
+      assert index + species.length() <= length();
+      return ByteVector.fromMemorySegment(species, segment, index, LITTLE_ENDIAN);
+    }
+
+    @Override
+    public byte tail(int index) {
+      assert index <= length();
+      return segment.get(JAVA_BYTE, index);
+    }
+  }
+
   @Override
   public int dotProduct(byte[] a, byte[] b) {
-    return dotProduct(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
+    return dotProductBody(new ArrayLoader(a), new ArrayLoader(b));
+  }
+
+  public static int dotProduct(byte[] a, MemorySegment b) {
+    return dotProductBody(new ArrayLoader(a), new MemorySegmentLoader(b));
   }
 
   public static int dotProduct(MemorySegment a, MemorySegment b) {
-    assert a.byteSize() == b.byteSize();
+    return dotProductBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
+  }
+
+  private static int dotProductBody(ByteVectorLoader a, ByteVectorLoader b) {
+    assert a.length() == b.length();
     int i = 0;
     int res = 0;
 
     // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
     // vectors (256-bit on intel to dodge performance landmines)
-    if (a.byteSize() >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
+    if (a.length() >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
       // compute vectorized dot product consistent with VPDPBUSD instruction
       if (VECTOR_BITSIZE >= 512) {
-        i += BYTE_SPECIES.loopBound(a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
         res += dotProductBody512(a, b, i);
       } else if (VECTOR_BITSIZE == 256) {
-        i += BYTE_SPECIES.loopBound(a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
         res += dotProductBody256(a, b, i);
       } else {
         // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
-        i += ByteVector.SPECIES_64.loopBound(a.byteSize() - ByteVector.SPECIES_64.length());
+        i += ByteVector.SPECIES_64.loopBound(a.length() - ByteVector.SPECIES_64.length());
         res += dotProductBody128(a, b, i);
       }
     }
 
     // scalar tail
-    for (; i < a.byteSize(); i++) {
-      res += b.get(JAVA_BYTE, i) * a.get(JAVA_BYTE, i);
+    for (; i < a.length(); i++) {
+      res += a.tail(i) * b.tail(i);
     }
     return res;
   }
 
   /** vectorized dot product body (512 bit vectors) */
-  private static int dotProductBody512(MemorySegment a, MemorySegment b, int limit) {
+  private static int dotProductBody512(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector acc = IntVector.zero(INT_SPECIES);
     for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
-      ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(BYTE_SPECIES, i);
+      ByteVector vb8 = b.load(BYTE_SPECIES, i);
 
       // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
       Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
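
Illustrative usage of the new ByteVectorLoader indirection (a sketch, not part of this change): it lets a single dotProductBody serve heap arrays, off-heap MemorySegments, and the mixed case. The snippet assumes access to the static overloads introduced above and wraps a heap array via MemorySegment.ofArray as a stand-in for an off-heap stored vector:

    byte[] query  = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    byte[] stored = {16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
    MemorySegment segment = MemorySegment.ofArray(stored);        // stand-in for an off-heap vector
    int mixed = dotProduct(query, segment);                        // ArrayLoader + MemorySegmentLoader
    int both  = dotProduct(MemorySegment.ofArray(query), segment); // MemorySegmentLoader on both sides
    assert mixed == both; // every overload funnels into the same dotProductBody
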
@@ -364,11 +418,11 @@ private static int dotProductBody512(MemorySegment a, MemorySegment b, int limit
   }
 
   /** vectorized dot product body (256 bit vectors) */
-  private static int dotProductBody256(MemorySegment a, MemorySegment b, int limit) {
+  private static int dotProductBody256(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector acc = IntVector.zero(IntVector.SPECIES_256);
     for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
-      ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+      ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
       // 32-bit multiply and add into accumulator
       Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0);
@@ -380,13 +434,13 @@ private static int dotProductBody256(MemorySegment a, MemorySegment b, int limit
   }
 
   /** vectorized dot product body (128 bit vectors) */
-  private static int dotProductBody128(MemorySegment a, MemorySegment b, int limit) {
+  private static int dotProductBody128(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector acc = IntVector.zero(IntVector.SPECIES_128);
     // 4 bytes at a time (re-loading half the vector each time!)
     for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
       // load 8 bytes
-      ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+      ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
       // process first "half" only: 16-bit multiply
       Vector<Short> va16 = va8.convert(B2S, 0);
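
A worked example of the "overlapping read" bound that feeds this 128-bit body (a sketch using an assumed 22-byte input, not a value from the patch):

    int lanes = ByteVector.SPECIES_64.length();              // 8
    int bound = ByteVector.SPECIES_64.loopBound(22 - lanes); // loopBound(14) == 8

With limit == 8 the loop above runs at i = 0 and i = 4 (the step is lanes >> 1 == 4), so the widest 8-byte load touches bytes 4..11, safely inside the 22-byte vector, and the scalar tail covers indices 8..21.
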
@@ -578,28 +632,36 @@ private int int4DotProductBody128(byte[] a, byte[] b, int limit) {
 
   @Override
   public float cosine(byte[] a, byte[] b) {
-    return cosine(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
+    return cosineBody(new ArrayLoader(a), new ArrayLoader(b));
   }
 
   public static float cosine(MemorySegment a, MemorySegment b) {
+    return cosineBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
+  }
+
+  public static float cosine(byte[] a, MemorySegment b) {
+    return cosineBody(new ArrayLoader(a), new MemorySegmentLoader(b));
+  }
+
+  private static float cosineBody(ByteVectorLoader a, ByteVectorLoader b) {
     int i = 0;
     int sum = 0;
     int norm1 = 0;
     int norm2 = 0;
 
     // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
     // vectors (256-bit on intel to dodge performance landmines)
-    if (a.byteSize() >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
+    if (a.length() >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
       final float[] ret;
       if (VECTOR_BITSIZE >= 512) {
-        i += BYTE_SPECIES.loopBound((int) a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
         ret = cosineBody512(a, b, i);
       } else if (VECTOR_BITSIZE == 256) {
-        i += BYTE_SPECIES.loopBound((int) a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
         ret = cosineBody256(a, b, i);
       } else {
         // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
-        i += ByteVector.SPECIES_64.loopBound(a.byteSize() - ByteVector.SPECIES_64.length());
+        i += ByteVector.SPECIES_64.loopBound(a.length() - ByteVector.SPECIES_64.length());
         ret = cosineBody128(a, b, i);
       }
       sum += ret[0];
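
Each vectorized cosine body returns its partial results as a float[] of {sum, norm1, norm2}; the lines that follow (and the method's return, which fall outside these hunks) presumably fold them into the usual cosine formula. A sketch of that final step, assuming the surrounding code is untouched by this patch:

    return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
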
@@ -608,9 +670,9 @@ public static float cosine(MemorySegment a, MemorySegment b) {
     }
 
     // scalar tail
-    for (; i < a.byteSize(); i++) {
-      byte elem1 = a.get(JAVA_BYTE, i);
-      byte elem2 = b.get(JAVA_BYTE, i);
+    for (; i < a.length(); i++) {
+      byte elem1 = a.tail(i);
+      byte elem2 = b.tail(i);
       sum += elem1 * elem2;
       norm1 += elem1 * elem1;
       norm2 += elem2 * elem2;
@@ -619,13 +681,13 @@ public static float cosine(MemorySegment a, MemorySegment b) {
   }
 
   /** vectorized cosine body (512 bit vectors) */
-  private static float[] cosineBody512(MemorySegment a, MemorySegment b, int limit) {
+  private static float[] cosineBody512(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector accSum = IntVector.zero(INT_SPECIES);
     IntVector accNorm1 = IntVector.zero(INT_SPECIES);
     IntVector accNorm2 = IntVector.zero(INT_SPECIES);
     for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
-      ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(BYTE_SPECIES, i);
+      ByteVector vb8 = b.load(BYTE_SPECIES, i);
 
       // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
       Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
@@ -649,13 +711,13 @@ private static float[] cosineBody512(MemorySegment a, MemorySegment b, int limit
   }
 
   /** vectorized cosine body (256 bit vectors) */
-  private static float[] cosineBody256(MemorySegment a, MemorySegment b, int limit) {
+  private static float[] cosineBody256(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector accSum = IntVector.zero(IntVector.SPECIES_256);
     IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_256);
     IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_256);
     for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
-      ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+      ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
       // 16-bit multiply, and add into accumulators
       Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0);
@@ -674,13 +736,13 @@ private static float[] cosineBody256(MemorySegment a, MemorySegment b, int limit
   }
 
   /** vectorized cosine body (128 bit vectors) */
-  private static float[] cosineBody128(MemorySegment a, MemorySegment b, int limit) {
+  private static float[] cosineBody128(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector accSum = IntVector.zero(IntVector.SPECIES_128);
     IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_128);
     IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_128);
     for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
-      ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+      ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
       // process first half only: 16-bit multiply
       Vector<Short> va16 = va8.convert(B2S, 0);
@@ -702,40 +764,48 @@ private static float[] cosineBody128(MemorySegment a, MemorySegment b, int limit
 
   @Override
   public int squareDistance(byte[] a, byte[] b) {
-    return squareDistance(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
+    return squareDistanceBody(new ArrayLoader(a), new ArrayLoader(b));
   }
 
   public static int squareDistance(MemorySegment a, MemorySegment b) {
-    assert a.byteSize() == b.byteSize();
+    return squareDistanceBody(new MemorySegmentLoader(a), new MemorySegmentLoader(b));
+  }
+
+  public static int squareDistance(byte[] a, MemorySegment b) {
+    return squareDistanceBody(new ArrayLoader(a), new MemorySegmentLoader(b));
+  }
+
+  private static int squareDistanceBody(ByteVectorLoader a, ByteVectorLoader b) {
+    assert a.length() == b.length();
     int i = 0;
     int res = 0;
 
     // only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
     // vectors (256-bit on intel to dodge performance landmines)
-    if (a.byteSize() >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
+    if (a.length() >= 16 && PanamaVectorConstants.HAS_FAST_INTEGER_VECTORS) {
       if (VECTOR_BITSIZE >= 256) {
-        i += BYTE_SPECIES.loopBound((int) a.byteSize());
+        i += BYTE_SPECIES.loopBound(a.length());
         res += squareDistanceBody256(a, b, i);
       } else {
-        i += ByteVector.SPECIES_64.loopBound((int) a.byteSize());
+        i += ByteVector.SPECIES_64.loopBound(a.length());
         res += squareDistanceBody128(a, b, i);
       }
     }
 
     // scalar tail
-    for (; i < a.byteSize(); i++) {
-      int diff = a.get(JAVA_BYTE, i) - b.get(JAVA_BYTE, i);
+    for (; i < a.length(); i++) {
+      int diff = a.tail(i) - b.tail(i);
       res += diff * diff;
     }
     return res;
   }
 
   /** vectorized square distance body (256+ bit vectors) */
-  private static int squareDistanceBody256(MemorySegment a, MemorySegment b, int limit) {
+  private static int squareDistanceBody256(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     IntVector acc = IntVector.zero(INT_SPECIES);
     for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
-      ByteVector va8 = ByteVector.fromMemorySegment(BYTE_SPECIES, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(BYTE_SPECIES, i);
+      ByteVector vb8 = b.load(BYTE_SPECIES, i);
 
       // 32-bit sub, multiply, and add into accumulators
       // TODO: uses AVX-512 heavy multiply on zmm, should we just use 256-bit vectors on AVX-512?
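
Because the byte[]/byte[], MemorySegment/MemorySegment, and mixed overloads of squareDistance all funnel into squareDistanceBody, they should agree exactly, including on lengths that exercise the scalar tail. A hypothetical consistency check (not part of this change; randomBytes is an assumed test helper returning a random byte[] of the given length):

    byte[] a = randomBytes(19); // 19 is not a lane multiple, so the tail() path runs too
    byte[] b = randomBytes(19);
    int viaSegments = squareDistance(MemorySegment.ofArray(a), MemorySegment.ofArray(b));
    int viaMixed    = squareDistance(a, MemorySegment.ofArray(b));
    assert viaSegments == viaMixed;
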
@@ -749,14 +819,14 @@ private static int squareDistanceBody256(MemorySegment a, MemorySegment b, int l
   }
 
   /** vectorized square distance body (128 bit vectors) */
-  private static int squareDistanceBody128(MemorySegment a, MemorySegment b, int limit) {
+  private static int squareDistanceBody128(ByteVectorLoader a, ByteVectorLoader b, int limit) {
     // 128-bit implementation, which must "split up" vectors due to widening conversions
     // it doesn't help to do the overlapping read trick, due to 32-bit multiply in the formula
     IntVector acc1 = IntVector.zero(IntVector.SPECIES_128);
     IntVector acc2 = IntVector.zero(IntVector.SPECIES_128);
     for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
-      ByteVector va8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, a, i, LITTLE_ENDIAN);
-      ByteVector vb8 = ByteVector.fromMemorySegment(ByteVector.SPECIES_64, b, i, LITTLE_ENDIAN);
+      ByteVector va8 = a.load(ByteVector.SPECIES_64, i);
+      ByteVector vb8 = b.load(ByteVector.SPECIES_64, i);
 
       // 16-bit sub
       Vector<Short> va16 = va8.convertShape(B2S, ShortVector.SPECIES_128, 0);