@@ -309,45 +309,99 @@ private float squareDistanceBody(float[] a, float[] b, int limit) {
309309 // We also support 128 bit vectors, going 32 bits at a time.
310310 // This is slower but still faster than not vectorizing at all.
311311
312+ private interface ByteVectorLoader {
313+ int length ();
314+
315+ ByteVector load (VectorSpecies <Byte > species , int index );
316+
317+ byte tail (int index );
318+ }
319+
320+ private record ArrayLoader (byte [] arr ) implements ByteVectorLoader {
321+ @ Override
322+ public int length () {
323+ return arr .length ;
324+ }
325+
326+ @ Override
327+ public ByteVector load (VectorSpecies <Byte > species , int index ) {
328+ assert index + species .length () < length ();
329+ return ByteVector .fromArray (species , arr , index );
330+ }
331+
332+ @ Override
333+ public byte tail (int index ) {
334+ assert index < length ();
335+ return arr [index ];
336+ }
337+ }
338+
339+ private record MemorySegmentLoader (MemorySegment segment ) implements ByteVectorLoader {
340+ @ Override
341+ public int length () {
342+ return Math .toIntExact (segment .byteSize ());
343+ }
344+
345+ @ Override
346+ public ByteVector load (VectorSpecies <Byte > species , int index ) {
347+ assert index + species .length () < length ();
348+ return ByteVector .fromMemorySegment (species , segment , index , LITTLE_ENDIAN );
349+ }
350+
351+ @ Override
352+ public byte tail (int index ) {
353+ assert index < length ();
354+ return segment .get (JAVA_BYTE , index );
355+ }
356+ }
357+
312358 @ Override
313359 public int dotProduct (byte [] a , byte [] b ) {
314- return dotProduct (MemorySegment .ofArray (a ), MemorySegment .ofArray (b ));
360+ return dotProductBody (new ArrayLoader (a ), new ArrayLoader (b ));
361+ }
362+
363+ public static int dotProduct (byte [] a , MemorySegment b ) {
364+ return dotProductBody (new ArrayLoader (a ), new MemorySegmentLoader (b ));
315365 }
316366
317367 public static int dotProduct (MemorySegment a , MemorySegment b ) {
318- assert a .byteSize () == b .byteSize ();
368+ return dotProductBody (new MemorySegmentLoader (a ), new MemorySegmentLoader (b ));
369+ }
370+
371+ private static int dotProductBody (ByteVectorLoader a , ByteVectorLoader b ) {
372+ assert a .length () == b .length ();
319373 int i = 0 ;
320374 int res = 0 ;
321375
322376 // only vectorize if we'll at least enter the loop a single time
323- if (a .byteSize () >= 16 ) {
377+ if (a .length () >= 16 ) {
324378 // compute vectorized dot product consistent with VPDPBUSD instruction
325379 if (VECTOR_BITSIZE >= 512 ) {
326- i += BYTE_SPECIES .loopBound (a .byteSize ());
380+ i += BYTE_SPECIES .loopBound (a .length ());
327381 res += dotProductBody512 (a , b , i );
328382 } else if (VECTOR_BITSIZE == 256 ) {
329- i += BYTE_SPECIES .loopBound (a .byteSize ());
383+ i += BYTE_SPECIES .loopBound (a .length ());
330384 res += dotProductBody256 (a , b , i );
331385 } else {
332386 // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
333- i += ByteVector .SPECIES_64 .loopBound (a .byteSize () - ByteVector .SPECIES_64 .length ());
387+ i += ByteVector .SPECIES_64 .loopBound (a .length () - ByteVector .SPECIES_64 .length ());
334388 res += dotProductBody128 (a , b , i );
335389 }
336390 }
337391
338392 // scalar tail
339- for (; i < a .byteSize (); i ++) {
340- res += b . get ( JAVA_BYTE , i ) * a . get ( JAVA_BYTE , i );
393+ for (; i < a .length (); i ++) {
394+ res += a . tail ( i ) * b . tail ( i );
341395 }
342396 return res ;
343397 }
344398
345399 /** vectorized dot product body (512 bit vectors) */
346- private static int dotProductBody512 (MemorySegment a , MemorySegment b , int limit ) {
400+ private static int dotProductBody512 (ByteVectorLoader a , ByteVectorLoader b , int limit ) {
347401 IntVector acc = IntVector .zero (INT_SPECIES );
348402 for (int i = 0 ; i < limit ; i += BYTE_SPECIES .length ()) {
349- ByteVector va8 = ByteVector . fromMemorySegment (BYTE_SPECIES , a , i , LITTLE_ENDIAN );
350- ByteVector vb8 = ByteVector . fromMemorySegment (BYTE_SPECIES , b , i , LITTLE_ENDIAN );
403+ ByteVector va8 = a . load (BYTE_SPECIES , i );
404+ ByteVector vb8 = b . load (BYTE_SPECIES , i );
351405
352406 // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
353407 Vector <Short > va16 = va8 .convertShape (B2S , SHORT_SPECIES , 0 );
@@ -363,11 +417,11 @@ private static int dotProductBody512(MemorySegment a, MemorySegment b, int limit
363417 }
364418
365419 /** vectorized dot product body (256 bit vectors) */
366- private static int dotProductBody256 (MemorySegment a , MemorySegment b , int limit ) {
420+ private static int dotProductBody256 (ByteVectorLoader a , ByteVectorLoader b , int limit ) {
367421 IntVector acc = IntVector .zero (IntVector .SPECIES_256 );
368422 for (int i = 0 ; i < limit ; i += ByteVector .SPECIES_64 .length ()) {
369- ByteVector va8 = ByteVector . fromMemorySegment (ByteVector .SPECIES_64 , a , i , LITTLE_ENDIAN );
370- ByteVector vb8 = ByteVector . fromMemorySegment (ByteVector .SPECIES_64 , b , i , LITTLE_ENDIAN );
423+ ByteVector va8 = a . load (ByteVector .SPECIES_64 , i );
424+ ByteVector vb8 = b . load (ByteVector .SPECIES_64 , i );
371425
372426 // 32-bit multiply and add into accumulator
373427 Vector <Integer > va32 = va8 .convertShape (B2I , IntVector .SPECIES_256 , 0 );
@@ -379,13 +433,13 @@ private static int dotProductBody256(MemorySegment a, MemorySegment b, int limit
379433 }
380434
381435 /** vectorized dot product body (128 bit vectors) */
382- private static int dotProductBody128 (MemorySegment a , MemorySegment b , int limit ) {
436+ private static int dotProductBody128 (ByteVectorLoader a , ByteVectorLoader b , int limit ) {
383437 IntVector acc = IntVector .zero (IntVector .SPECIES_128 );
384438 // 4 bytes at a time (re-loading half the vector each time!)
385439 for (int i = 0 ; i < limit ; i += ByteVector .SPECIES_64 .length () >> 1 ) {
386440 // load 8 bytes
387- ByteVector va8 = ByteVector . fromMemorySegment (ByteVector .SPECIES_64 , a , i , LITTLE_ENDIAN );
388- ByteVector vb8 = ByteVector . fromMemorySegment (ByteVector .SPECIES_64 , b , i , LITTLE_ENDIAN );
441+ ByteVector va8 = a . load (ByteVector .SPECIES_64 , i );
442+ ByteVector vb8 = b . load (ByteVector .SPECIES_64 , i );
389443
390444 // process first "half" only: 16-bit multiply
391445 Vector <Short > va16 = va8 .convert (B2S , 0 );
@@ -577,27 +631,35 @@ private int int4DotProductBody128(byte[] a, byte[] b, int limit) {
577631
578632 @ Override
579633 public float cosine (byte [] a , byte [] b ) {
580- return cosine ( MemorySegment . ofArray (a ), MemorySegment . ofArray (b ));
634+ return cosineBody ( new ArrayLoader (a ), new ArrayLoader (b ));
581635 }
582636
583637 public static float cosine (MemorySegment a , MemorySegment b ) {
638+ return cosineBody (new MemorySegmentLoader (a ), new MemorySegmentLoader (b ));
639+ }
640+
641+ public static float cosine (byte [] a , MemorySegment b ) {
642+ return cosineBody (new ArrayLoader (a ), new MemorySegmentLoader (b ));
643+ }
644+
645+ private static float cosineBody (ByteVectorLoader a , ByteVectorLoader b ) {
584646 int i = 0 ;
585647 int sum = 0 ;
586648 int norm1 = 0 ;
587649 int norm2 = 0 ;
588650
589651 // only vectorize if we'll at least enter the loop a single time
590- if (a .byteSize () >= 16 ) {
652+ if (a .length () >= 16 ) {
591653 final float [] ret ;
592654 if (VECTOR_BITSIZE >= 512 ) {
593- i += BYTE_SPECIES .loopBound (( int ) a . byteSize ());
655+ i += BYTE_SPECIES .loopBound (a . length ());
594656 ret = cosineBody512 (a , b , i );
595657 } else if (VECTOR_BITSIZE == 256 ) {
596- i += BYTE_SPECIES .loopBound (( int ) a . byteSize ());
658+ i += BYTE_SPECIES .loopBound (a . length ());
597659 ret = cosineBody256 (a , b , i );
598660 } else {
599661 // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
600- i += ByteVector .SPECIES_64 .loopBound (a .byteSize () - ByteVector .SPECIES_64 .length ());
662+ i += ByteVector .SPECIES_64 .loopBound (a .length () - ByteVector .SPECIES_64 .length ());
601663 ret = cosineBody128 (a , b , i );
602664 }
603665 sum += ret [0 ];
@@ -606,9 +668,9 @@ public static float cosine(MemorySegment a, MemorySegment b) {
606668 }
607669
608670 // scalar tail
609- for (; i < a .byteSize (); i ++) {
610- byte elem1 = a .get ( JAVA_BYTE , i );
611- byte elem2 = b .get ( JAVA_BYTE , i );
671+ for (; i < a .length (); i ++) {
672+ byte elem1 = a .tail ( i );
673+ byte elem2 = b .tail ( i );
612674 sum += elem1 * elem2 ;
613675 norm1 += elem1 * elem1 ;
614676 norm2 += elem2 * elem2 ;
@@ -617,13 +679,13 @@ public static float cosine(MemorySegment a, MemorySegment b) {
617679 }
618680
619681 /** vectorized cosine body (512 bit vectors) */
620- private static float [] cosineBody512 (MemorySegment a , MemorySegment b , int limit ) {
682+ private static float [] cosineBody512 (ByteVectorLoader a , ByteVectorLoader b , int limit ) {
621683 IntVector accSum = IntVector .zero (INT_SPECIES );
622684 IntVector accNorm1 = IntVector .zero (INT_SPECIES );
623685 IntVector accNorm2 = IntVector .zero (INT_SPECIES );
624686 for (int i = 0 ; i < limit ; i += BYTE_SPECIES .length ()) {
625- ByteVector va8 = ByteVector . fromMemorySegment (BYTE_SPECIES , a , i , LITTLE_ENDIAN );
626- ByteVector vb8 = ByteVector . fromMemorySegment (BYTE_SPECIES , b , i , LITTLE_ENDIAN );
687+ ByteVector va8 = a . load (BYTE_SPECIES , i );
688+ ByteVector vb8 = b . load (BYTE_SPECIES , i );
627689
628690 // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
629691 Vector <Short > va16 = va8 .convertShape (B2S , SHORT_SPECIES , 0 );
@@ -647,13 +709,13 @@ private static float[] cosineBody512(MemorySegment a, MemorySegment b, int limit
647709 }
648710
649711 /** vectorized cosine body (256 bit vectors) */
650- private static float [] cosineBody256 (MemorySegment a , MemorySegment b , int limit ) {
712+ private static float [] cosineBody256 (ByteVectorLoader a , ByteVectorLoader b , int limit ) {
651713 IntVector accSum = IntVector .zero (IntVector .SPECIES_256 );
652714 IntVector accNorm1 = IntVector .zero (IntVector .SPECIES_256 );
653715 IntVector accNorm2 = IntVector .zero (IntVector .SPECIES_256 );
654716 for (int i = 0 ; i < limit ; i += ByteVector .SPECIES_64 .length ()) {
655- ByteVector va8 = ByteVector . fromMemorySegment (ByteVector .SPECIES_64 , a , i , LITTLE_ENDIAN );
656- ByteVector vb8 = ByteVector . fromMemorySegment (ByteVector .SPECIES_64 , b , i , LITTLE_ENDIAN );
717+ ByteVector va8 = a . load (ByteVector .SPECIES_64 , i );
718+ ByteVector vb8 = b . load (ByteVector .SPECIES_64 , i );
657719
658720 // 16-bit multiply, and add into accumulators
659721 Vector <Integer > va32 = va8 .convertShape (B2I , IntVector .SPECIES_256 , 0 );
@@ -672,13 +734,13 @@ private static float[] cosineBody256(MemorySegment a, MemorySegment b, int limit
672734 }
673735
674736 /** vectorized cosine body (128 bit vectors) */
675- private static float [] cosineBody128 (MemorySegment a , MemorySegment b , int limit ) {
737+ private static float [] cosineBody128 (ByteVectorLoader a , ByteVectorLoader b , int limit ) {
676738 IntVector accSum = IntVector .zero (IntVector .SPECIES_128 );
677739 IntVector accNorm1 = IntVector .zero (IntVector .SPECIES_128 );
678740 IntVector accNorm2 = IntVector .zero (IntVector .SPECIES_128 );
679741 for (int i = 0 ; i < limit ; i += ByteVector .SPECIES_64 .length () >> 1 ) {
680- ByteVector va8 = ByteVector . fromMemorySegment (ByteVector .SPECIES_64 , a , i , LITTLE_ENDIAN );
681- ByteVector vb8 = ByteVector . fromMemorySegment (ByteVector .SPECIES_64 , b , i , LITTLE_ENDIAN );
742+ ByteVector va8 = a . load (ByteVector .SPECIES_64 , i );
743+ ByteVector vb8 = b . load (ByteVector .SPECIES_64 , i );
682744
683745 // process first half only: 16-bit multiply
684746 Vector <Short > va16 = va8 .convert (B2S , 0 );
@@ -700,39 +762,47 @@ private static float[] cosineBody128(MemorySegment a, MemorySegment b, int limit
700762
701763 @ Override
702764 public int squareDistance (byte [] a , byte [] b ) {
703- return squareDistance ( MemorySegment . ofArray (a ), MemorySegment . ofArray (b ));
765+ return squareDistanceBody ( new ArrayLoader (a ), new ArrayLoader (b ));
704766 }
705767
706768 public static int squareDistance (MemorySegment a , MemorySegment b ) {
707- assert a .byteSize () == b .byteSize ();
769+ return squareDistanceBody (new MemorySegmentLoader (a ), new MemorySegmentLoader (b ));
770+ }
771+
772+ public static int squareDistance (byte [] a , MemorySegment b ) {
773+ return squareDistanceBody (new ArrayLoader (a ), new MemorySegmentLoader (b ));
774+ }
775+
776+ private static int squareDistanceBody (ByteVectorLoader a , ByteVectorLoader b ) {
777+ assert a .length () == b .length ();
708778 int i = 0 ;
709779 int res = 0 ;
710780
711781 // only vectorize if we'll at least enter the loop a single time
712- if (a .byteSize () >= 16 ) {
782+ if (a .length () >= 16 ) {
713783 if (VECTOR_BITSIZE >= 256 ) {
714- i += BYTE_SPECIES .loopBound (( int ) a . byteSize ());
784+ i += BYTE_SPECIES .loopBound (a . length ());
715785 res += squareDistanceBody256 (a , b , i );
716786 } else {
717- i += ByteVector .SPECIES_64 .loopBound (( int ) a . byteSize ());
787+ i += ByteVector .SPECIES_64 .loopBound (a . length ());
718788 res += squareDistanceBody128 (a , b , i );
719789 }
720790 }
721791
722792 // scalar tail
723- for (; i < a .byteSize (); i ++) {
724- int diff = a .get ( JAVA_BYTE , i ) - b .get ( JAVA_BYTE , i );
793+ for (; i < a .length (); i ++) {
794+ int diff = a .tail ( i ) - b .tail ( i );
725795 res += diff * diff ;
726796 }
727797 return res ;
728798 }
729799
730800 /** vectorized square distance body (256+ bit vectors) */
731- private static int squareDistanceBody256 (MemorySegment a , MemorySegment b , int limit ) {
801+ private static int squareDistanceBody256 (ByteVectorLoader a , ByteVectorLoader b , int limit ) {
732802 IntVector acc = IntVector .zero (INT_SPECIES );
733803 for (int i = 0 ; i < limit ; i += BYTE_SPECIES .length ()) {
734- ByteVector va8 = ByteVector . fromMemorySegment (BYTE_SPECIES , a , i , LITTLE_ENDIAN );
735- ByteVector vb8 = ByteVector . fromMemorySegment (BYTE_SPECIES , b , i , LITTLE_ENDIAN );
804+ ByteVector va8 = a . load (BYTE_SPECIES , i );
805+ ByteVector vb8 = b . load (BYTE_SPECIES , i );
736806
737807 // 32-bit sub, multiply, and add into accumulators
738808 // TODO: uses AVX-512 heavy multiply on zmm, should we just use 256-bit vectors on AVX-512?
@@ -746,14 +816,14 @@ private static int squareDistanceBody256(MemorySegment a, MemorySegment b, int l
746816 }
747817
748818 /** vectorized square distance body (128 bit vectors) */
749- private static int squareDistanceBody128 (MemorySegment a , MemorySegment b , int limit ) {
819+ private static int squareDistanceBody128 (ByteVectorLoader a , ByteVectorLoader b , int limit ) {
750820 // 128-bit implementation, which must "split up" vectors due to widening conversions
751821 // it doesn't help to do the overlapping read trick, due to 32-bit multiply in the formula
752822 IntVector acc1 = IntVector .zero (IntVector .SPECIES_128 );
753823 IntVector acc2 = IntVector .zero (IntVector .SPECIES_128 );
754824 for (int i = 0 ; i < limit ; i += ByteVector .SPECIES_64 .length ()) {
755- ByteVector va8 = ByteVector . fromMemorySegment (ByteVector .SPECIES_64 , a , i , LITTLE_ENDIAN );
756- ByteVector vb8 = ByteVector . fromMemorySegment (ByteVector .SPECIES_64 , b , i , LITTLE_ENDIAN );
825+ ByteVector va8 = a . load (ByteVector .SPECIES_64 , i );
826+ ByteVector vb8 = b . load (ByteVector .SPECIES_64 , i );
757827
758828 // 16-bit sub
759829 Vector <Short > va16 = va8 .convertShape (B2S , ShortVector .SPECIES_128 , 0 );
0 commit comments