@@ -54,11 +54,13 @@ static void compressBytes(byte[] raw, byte[] compressed) {
5454 private byte [] bytesA ;
5555 private byte [] bytesB ;
5656 private byte [] halfBytesA ;
57+ private byte [] halfBytesAPacked ;
5758 private byte [] halfBytesB ;
5859 private byte [] halfBytesBPacked ;
5960 private float [] floatsA ;
6061 private float [] floatsB ;
61- private int expectedhalfByteDotProduct ;
62+ private int expectedHalfByteDotProduct ;
63+ private int expectedHalfByteSquareDistance ;
6264
6365 @ Param ({"1" , "128" , "207" , "256" , "300" , "512" , "702" , "1024" })
6466 int size ;
@@ -74,16 +76,23 @@ public void init() {
7476 random .nextBytes (bytesB );
7577 // random half byte arrays for binary methods
7678 // this means that all values must be between 0 and 15
77- expectedhalfByteDotProduct = 0 ;
79+ expectedHalfByteDotProduct = 0 ;
80+ expectedHalfByteSquareDistance = 0 ;
7881 halfBytesA = new byte [size ];
7982 halfBytesB = new byte [size ];
8083 for (int i = 0 ; i < size ; ++i ) {
8184 halfBytesA [i ] = (byte ) random .nextInt (16 );
8285 halfBytesB [i ] = (byte ) random .nextInt (16 );
83- expectedhalfByteDotProduct += halfBytesA [i ] * halfBytesB [i ];
86+ expectedHalfByteDotProduct += halfBytesA [i ] * halfBytesB [i ];
87+
88+ int diff = halfBytesA [i ] - halfBytesB [i ];
89+ expectedHalfByteSquareDistance += diff * diff ;
8490 }
8591 // pack the half byte arrays
8692 if (size % 2 == 0 ) {
93+ halfBytesAPacked = new byte [(size + 1 ) >> 1 ];
94+ compressBytes (halfBytesA , halfBytesAPacked );
95+
8796 halfBytesBPacked = new byte [(size + 1 ) >> 1 ];
8897 compressBytes (halfBytesB , halfBytesBPacked );
8998 }
@@ -108,6 +117,74 @@ public float binaryCosineVector() {
108117 return VectorUtil .cosine (bytesA , bytesB );
109118 }
110119
120+ @ Benchmark
121+ public int binarySquareScalar () {
122+ return VectorUtil .squareDistance (bytesA , bytesB );
123+ }
124+
125+ @ Benchmark
126+ @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
127+ public int binarySquareVector () {
128+ return VectorUtil .squareDistance (bytesA , bytesB );
129+ }
130+
131+ @ Benchmark
132+ public int binaryHalfByteSquareScalar () {
133+ int v = VectorUtil .int4SquareDistance (halfBytesA , halfBytesB );
134+ if (v != expectedHalfByteSquareDistance ) {
135+ throw new RuntimeException ("Expected " + expectedHalfByteSquareDistance + " but got " + v );
136+ }
137+ return v ;
138+ }
139+
140+ @ Benchmark
141+ @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
142+ public int binaryHalfByteSquareVector () {
143+ int v = VectorUtil .int4SquareDistance (halfBytesA , halfBytesB );
144+ if (v != expectedHalfByteSquareDistance ) {
145+ throw new RuntimeException ("Expected " + expectedHalfByteSquareDistance + " but got " + v );
146+ }
147+ return v ;
148+ }
149+
150+ @ Benchmark
151+ public int binaryHalfByteSquareSinglePackedScalar () {
152+ int v = VectorUtil .int4SquareDistanceSinglePacked (halfBytesA , halfBytesBPacked );
153+ if (v != expectedHalfByteSquareDistance ) {
154+ throw new RuntimeException ("Expected " + expectedHalfByteSquareDistance + " but got " + v );
155+ }
156+ return v ;
157+ }
158+
159+ @ Benchmark
160+ @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
161+ public int binaryHalfByteSquareSinglePackedVector () {
162+ int v = VectorUtil .int4SquareDistanceSinglePacked (halfBytesA , halfBytesBPacked );
163+ if (v != expectedHalfByteSquareDistance ) {
164+ throw new RuntimeException ("Expected " + expectedHalfByteSquareDistance + " but got " + v );
165+ }
166+ return v ;
167+ }
168+
169+ @ Benchmark
170+ public int binaryHalfByteSquareBothPackedScalar () {
171+ int v = VectorUtil .int4SquareDistanceBothPacked (halfBytesAPacked , halfBytesBPacked );
172+ if (v != expectedHalfByteSquareDistance ) {
173+ throw new RuntimeException ("Expected " + expectedHalfByteSquareDistance + " but got " + v );
174+ }
175+ return v ;
176+ }
177+
178+ @ Benchmark
179+ @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
180+ public int binaryHalfByteSquareBothPackedVector () {
181+ int v = VectorUtil .int4SquareDistanceBothPacked (halfBytesAPacked , halfBytesBPacked );
182+ if (v != expectedHalfByteSquareDistance ) {
183+ throw new RuntimeException ("Expected " + expectedHalfByteSquareDistance + " but got " + v );
184+ }
185+ return v ;
186+ }
187+
111188 @ Benchmark
112189 public int binaryDotProductScalar () {
113190 return VectorUtil .dotProduct (bytesA , bytesB );
@@ -131,14 +208,22 @@ public int binaryDotProductUint8Vector() {
131208 }
132209
133210 @ Benchmark
134- public int binarySquareScalar () {
135- return VectorUtil .squareDistance (bytesA , bytesB );
211+ public int binaryHalfByteDotProductScalar () {
212+ int v = VectorUtil .int4DotProduct (halfBytesA , halfBytesB );
213+ if (v != expectedHalfByteDotProduct ) {
214+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
215+ }
216+ return v ;
136217 }
137218
138219 @ Benchmark
139220 @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
140- public int binarySquareVector () {
141- return VectorUtil .squareDistance (bytesA , bytesB );
221+ public int binaryHalfByteDotProductVector () {
222+ int v = VectorUtil .int4DotProduct (halfBytesA , halfBytesB );
223+ if (v != expectedHalfByteDotProduct ) {
224+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
225+ }
226+ return v ;
142227 }
143228
144229 @ Benchmark
@@ -153,37 +238,39 @@ public int binarySquareUint8Vector() {
153238 }
154239
155240 @ Benchmark
156- public int binaryHalfByteScalar () {
157- return VectorUtil .int4DotProduct (halfBytesA , halfBytesB );
241+ public int binaryHalfByteDotProductSinglePackedScalar () {
242+ int v = VectorUtil .int4DotProductSinglePacked (halfBytesA , halfBytesBPacked );
243+ if (v != expectedHalfByteDotProduct ) {
244+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
245+ }
246+ return v ;
158247 }
159248
160249 @ Benchmark
161250 @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
162- public int binaryHalfByteVector () {
163- return VectorUtil .int4DotProduct (halfBytesA , halfBytesB );
251+ public int binaryHalfByteDotProductSinglePackedVector () {
252+ int v = VectorUtil .int4DotProductSinglePacked (halfBytesA , halfBytesBPacked );
253+ if (v != expectedHalfByteDotProduct ) {
254+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
255+ }
256+ return v ;
164257 }
165258
166259 @ Benchmark
167- public int binaryHalfByteScalarPacked () {
168- if (size % 2 != 0 ) {
169- throw new RuntimeException ("Size must be even for this benchmark" );
170- }
171- int v = VectorUtil .int4DotProductPacked (halfBytesA , halfBytesBPacked );
172- if (v != expectedhalfByteDotProduct ) {
173- throw new RuntimeException ("Expected " + expectedhalfByteDotProduct + " but got " + v );
260+ public int binaryHalfByteDotProductBothPackedScalar () {
261+ int v = VectorUtil .int4DotProductBothPacked (halfBytesAPacked , halfBytesBPacked );
262+ if (v != expectedHalfByteDotProduct ) {
263+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
174264 }
175265 return v ;
176266 }
177267
178268 @ Benchmark
179269 @ Fork (jvmArgsPrepend = {"--add-modules=jdk.incubator.vector" })
180- public int binaryHalfByteVectorPacked () {
181- if (size % 2 != 0 ) {
182- throw new RuntimeException ("Size must be even for this benchmark" );
183- }
184- int v = VectorUtil .int4DotProductPacked (halfBytesA , halfBytesBPacked );
185- if (v != expectedhalfByteDotProduct ) {
186- throw new RuntimeException ("Expected " + expectedhalfByteDotProduct + " but got " + v );
270+ public int binaryHalfByteDotProductBothPackedVector () {
271+ int v = VectorUtil .int4DotProductBothPacked (halfBytesAPacked , halfBytesBPacked );
272+ if (v != expectedHalfByteDotProduct ) {
273+ throw new RuntimeException ("Expected " + expectedHalfByteDotProduct + " but got " + v );
187274 }
188275 return v ;
189276 }
0 commit comments