@@ -43,11 +43,13 @@ static void compressBytes(byte[] raw, byte[] compressed) {
4343 private byte [] bytesA ;
4444 private byte [] bytesB ;
4545 private byte [] halfBytesA ;
46+ private byte [] halfBytesAPacked ;
4647 private byte [] halfBytesB ;
4748 private byte [] halfBytesBPacked ;
4849 private float [] floatsA ;
4950 private float [] floatsB ;
50- private int expectedhalfByteDotProduct ;
51+ private int expectedHalfByteDotProduct ;
52+ private int expectedHalfByteSquareDistance ;
5153
5254 @ Param ({"1" , "128" , "207" , "256" , "300" , "512" , "702" , "1024" })
5355 int size ;
@@ -63,16 +65,23 @@ public void init() {
6365 random .nextBytes (bytesB );
6466 // random half byte arrays for binary methods
6567 // this means that all values must be between 0 and 15
66- expectedhalfByteDotProduct = 0 ;
68+ expectedHalfByteDotProduct = 0 ;
69+ expectedHalfByteSquareDistance = 0 ;
6770 halfBytesA = new byte [size ];
6871 halfBytesB = new byte [size ];
6972 for (int i = 0 ; i < size ; ++i ) {
7073 halfBytesA [i ] = (byte ) random .nextInt (16 );
7174 halfBytesB [i ] = (byte ) random .nextInt (16 );
72- expectedhalfByteDotProduct += halfBytesA [i ] * halfBytesB [i ];
75+ expectedHalfByteDotProduct += halfBytesA [i ] * halfBytesB [i ];
76+
77+ int diff = halfBytesA [i ] - halfBytesB [i ];
78+ expectedHalfByteSquareDistance += diff * diff ;
7379 }
7480 // pack the half byte arrays
7581 if (size % 2 == 0 ) {
82+ halfBytesAPacked = new byte [(size + 1 ) >> 1 ];
83+ compressBytes (halfBytesA , halfBytesAPacked );
84+
7685 halfBytesBPacked = new byte [(size + 1 ) >> 1 ];
7786 compressBytes (halfBytesB , halfBytesBPacked );
7887 }
@@ -97,6 +106,74 @@ public float binaryCosineVector() {
97106 return VectorUtil .cosine (bytesA , bytesB );
98107 }
99108
109+ @ Benchmark
110+ public int binarySquareScalar () {
111+ return VectorUtil .squareDistance (bytesA , bytesB );
112+ }
113+
114+ @ Benchmark
115+ @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
116+ public int binarySquareVector () {
117+ return VectorUtil .squareDistance (bytesA , bytesB );
118+ }
119+
120+ @ Benchmark
121+ public int binaryHalfByteSquareScalar () {
122+ int v = VectorUtil .int4SquareDistance (halfBytesA , halfBytesB );
123+ if (v != expectedHalfByteSquareDistance ) {
124+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
125+ }
126+ return v ;
127+ }
128+
129+ @ Benchmark
130+ @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
131+ public int binaryHalfByteSquareVector () {
132+ int v = VectorUtil .int4SquareDistance (halfBytesA , halfBytesB );
133+ if (v != expectedHalfByteSquareDistance ) {
134+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
135+ }
136+ return v ;
137+ }
138+
139+ @ Benchmark
140+ public int binaryHalfByteSquareSinglePackedScalar () {
141+ int v = VectorUtil .int4SquareDistanceSinglePacked (halfBytesA , halfBytesBPacked );
142+ if (v != expectedHalfByteSquareDistance ) {
143+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
144+ }
145+ return v ;
146+ }
147+
148+ @ Benchmark
149+ @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
150+ public int binaryHalfByteSquareSinglePackedVector () {
151+ int v = VectorUtil .int4SquareDistanceSinglePacked (halfBytesA , halfBytesBPacked );
152+ if (v != expectedHalfByteSquareDistance ) {
153+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
154+ }
155+ return v ;
156+ }
157+
158+ @ Benchmark
159+ public int binaryHalfByteSquareBothPackedScalar () {
160+ int v = VectorUtil .int4SquareDistanceBothPacked (halfBytesAPacked , halfBytesBPacked );
161+ if (v != expectedHalfByteSquareDistance ) {
162+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
163+ }
164+ return v ;
165+ }
166+
167+ @ Benchmark
168+ @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
169+ public int binaryHalfByteSquareBothPackedVector () {
170+ int v = VectorUtil .int4SquareDistanceBothPacked (halfBytesAPacked , halfBytesBPacked );
171+ if (v != expectedHalfByteSquareDistance ) {
172+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
173+ }
174+ return v ;
175+ }
176+
100177 @ Benchmark
101178 public int binaryDotProductScalar () {
102179 return VectorUtil .dotProduct (bytesA , bytesB );
@@ -120,14 +197,22 @@ public int binaryDotProductUint8Vector() {
120197 }
121198
122199 @ Benchmark
123- public int binarySquareScalar () {
124- return VectorUtil .squareDistance (bytesA , bytesB );
200+ public int binaryHalfByteDotProductScalar () {
201+ int v = VectorUtil .int4DotProduct (halfBytesA , halfBytesB );
202+ if (v != expectedHalfByteDotProduct ) {
203+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
204+ }
205+ return v ;
125206 }
126207
127208 @ Benchmark
128209 @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
129- public int binarySquareVector () {
130- return VectorUtil .squareDistance (bytesA , bytesB );
210+ public int binaryHalfByteDotProductVector () {
211+ int v = VectorUtil .int4DotProduct (halfBytesA , halfBytesB );
212+ if (v != expectedHalfByteDotProduct ) {
213+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
214+ }
215+ return v ;
131216 }
132217
133218 @ Benchmark
@@ -142,37 +227,39 @@ public int binarySquareUint8Vector() {
142227 }
143228
144229 @ Benchmark
145- public int binaryHalfByteScalar () {
146- return VectorUtil .int4DotProduct (halfBytesA , halfBytesB );
230+ public int binaryHalfByteDotProductSinglePackedScalar () {
231+ int v = VectorUtil .int4DotProductSinglePacked (halfBytesA , halfBytesBPacked );
232+ if (v != expectedHalfByteDotProduct ) {
233+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
234+ }
235+ return v ;
147236 }
148237
149238 @ Benchmark
150239 @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
151- public int binaryHalfByteVector () {
152- return VectorUtil .int4DotProduct (halfBytesA , halfBytesB );
240+ public int binaryHalfByteDotProductSinglePackedVector () {
241+ int v = VectorUtil .int4DotProductSinglePacked (halfBytesA , halfBytesBPacked );
242+ if (v != expectedHalfByteDotProduct ) {
243+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
244+ }
245+ return v ;
153246 }
154247
155248 @ Benchmark
156- public int binaryHalfByteScalarPacked () {
157- if (size % 2 != 0 ) {
158- throw new RuntimeException ("Size must be even for this benchmark" );
159- }
160- int v = VectorUtil .int4DotProductPacked (halfBytesA , halfBytesBPacked );
161- if (v != expectedhalfByteDotProduct ) {
162- throw new RuntimeException ("Expected " + expectedhalfByteDotProduct + " but got " + v );
249+ public int binaryHalfByteDotProductBothPackedScalar () {
250+ int v = VectorUtil .int4DotProductBothPacked (halfBytesAPacked , halfBytesBPacked );
251+ if (v != expectedHalfByteDotProduct ) {
252+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
163253 }
164254 return v ;
165255 }
166256
167257 @ Benchmark
168258 @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
169- public int binaryHalfByteVectorPacked () {
170- if (size % 2 != 0 ) {
171- throw new RuntimeException ("Size must be even for this benchmark" );
172- }
173- int v = VectorUtil .int4DotProductPacked (halfBytesA , halfBytesBPacked );
174- if (v != expectedhalfByteDotProduct ) {
175- throw new RuntimeException ("Expected " + expectedhalfByteDotProduct + " but got " + v );
259+ public int binaryHalfByteDotProductBothPackedVector () {
260+ int v = VectorUtil .int4DotProductBothPacked (halfBytesAPacked , halfBytesBPacked );
261+ if (v != expectedHalfByteDotProduct ) {
262+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
176263 }
177264 return v ;
178265 }
0 commit comments