1717
1818package org .apache .lucene .internal .vectorization ;
1919
20+ import org .apache .lucene .util .Constants ;
21+ import org .apache .lucene .util .SuppressForbidden ;
22+
2023final class DefaultVectorUtilSupport implements VectorUtilSupport {
2124
// Constructor declared without an access modifier; its body is empty.
2225 DefaultVectorUtilSupport () {}
2326
27+ // the way FMA should work! if available use it, otherwise fall back to mul/add
28+ @ SuppressForbidden (reason = "Uses FMA only where fast and carefully contained" )
29+ private static float fma (float a , float b , float c ) {
30+ if (Constants .HAS_FAST_SCALAR_FMA ) {
31+ return Math .fma (a , b , c );
32+ } else {
33+ return a * b + c ;
34+ }
35+ }
36+
2437 @ Override
2538 public float dotProduct (float [] a , float [] b ) {
2639 float res = 0f ;
27- /*
28- * If length of vector is larger than 8, we use unrolled dot product to accelerate the
29- * calculation.
30- */
31- int i ;
32- for (i = 0 ; i < a .length % 8 ; i ++) {
33- res += b [i ] * a [i ];
34- }
35- if (a .length < 8 ) {
36- return res ;
37- }
38- for (; i + 31 < a .length ; i += 32 ) {
39- res +=
40- b [i + 0 ] * a [i + 0 ]
41- + b [i + 1 ] * a [i + 1 ]
42- + b [i + 2 ] * a [i + 2 ]
43- + b [i + 3 ] * a [i + 3 ]
44- + b [i + 4 ] * a [i + 4 ]
45- + b [i + 5 ] * a [i + 5 ]
46- + b [i + 6 ] * a [i + 6 ]
47- + b [i + 7 ] * a [i + 7 ];
48- res +=
49- b [i + 8 ] * a [i + 8 ]
50- + b [i + 9 ] * a [i + 9 ]
51- + b [i + 10 ] * a [i + 10 ]
52- + b [i + 11 ] * a [i + 11 ]
53- + b [i + 12 ] * a [i + 12 ]
54- + b [i + 13 ] * a [i + 13 ]
55- + b [i + 14 ] * a [i + 14 ]
56- + b [i + 15 ] * a [i + 15 ];
57- res +=
58- b [i + 16 ] * a [i + 16 ]
59- + b [i + 17 ] * a [i + 17 ]
60- + b [i + 18 ] * a [i + 18 ]
61- + b [i + 19 ] * a [i + 19 ]
62- + b [i + 20 ] * a [i + 20 ]
63- + b [i + 21 ] * a [i + 21 ]
64- + b [i + 22 ] * a [i + 22 ]
65- + b [i + 23 ] * a [i + 23 ];
66- res +=
67- b [i + 24 ] * a [i + 24 ]
68- + b [i + 25 ] * a [i + 25 ]
69- + b [i + 26 ] * a [i + 26 ]
70- + b [i + 27 ] * a [i + 27 ]
71- + b [i + 28 ] * a [i + 28 ]
72- + b [i + 29 ] * a [i + 29 ]
73- + b [i + 30 ] * a [i + 30 ]
74- + b [i + 31 ] * a [i + 31 ];
40+ int i = 0 ;
41+
42+ // if the array is big, unroll it
43+ if (a .length > 32 ) {
44+ float acc1 = 0 ;
45+ float acc2 = 0 ;
46+ float acc3 = 0 ;
47+ float acc4 = 0 ;
48+ int upperBound = a .length & ~(4 - 1 );
49+ for (; i < upperBound ; i += 4 ) {
50+ acc1 = fma (a [i ], b [i ], acc1 );
51+ acc2 = fma (a [i + 1 ], b [i + 1 ], acc2 );
52+ acc3 = fma (a [i + 2 ], b [i + 2 ], acc3 );
53+ acc4 = fma (a [i + 3 ], b [i + 3 ], acc4 );
54+ }
55+ res += acc1 + acc2 + acc3 + acc4 ;
7556 }
76- for (; i + 7 < a .length ; i += 8 ) {
77- res +=
78- b [i + 0 ] * a [i + 0 ]
79- + b [i + 1 ] * a [i + 1 ]
80- + b [i + 2 ] * a [i + 2 ]
81- + b [i + 3 ] * a [i + 3 ]
82- + b [i + 4 ] * a [i + 4 ]
83- + b [i + 5 ] * a [i + 5 ]
84- + b [i + 6 ] * a [i + 6 ]
85- + b [i + 7 ] * a [i + 7 ];
57+
58+ for (; i < a .length ; i ++) {
59+ res = fma (a [i ], b [i ], res );
8660 }
8761 return res ;
8862 }
@@ -92,50 +66,80 @@ public float cosine(float[] a, float[] b) {
9266 float sum = 0.0f ;
9367 float norm1 = 0.0f ;
9468 float norm2 = 0.0f ;
95- int dim = a . length ;
69+ int i = 0 ;
9670
97- for (int i = 0 ; i < dim ; i ++) {
98- float elem1 = a [i ];
99- float elem2 = b [i ];
100- sum += elem1 * elem2 ;
101- norm1 += elem1 * elem1 ;
102- norm2 += elem2 * elem2 ;
71+ // if the array is big, unroll it
72+ if (a .length > 32 ) {
73+ float sum1 = 0 ;
74+ float sum2 = 0 ;
75+ float norm1_1 = 0 ;
76+ float norm1_2 = 0 ;
77+ float norm2_1 = 0 ;
78+ float norm2_2 = 0 ;
79+
80+ int upperBound = a .length & ~(2 - 1 );
81+ for (; i < upperBound ; i += 2 ) {
82+ // one
83+ sum1 = fma (a [i ], b [i ], sum1 );
84+ norm1_1 = fma (a [i ], a [i ], norm1_1 );
85+ norm2_1 = fma (b [i ], b [i ], norm2_1 );
86+
87+ // two
88+ sum2 = fma (a [i + 1 ], b [i + 1 ], sum2 );
89+ norm1_2 = fma (a [i + 1 ], a [i + 1 ], norm1_2 );
90+ norm2_2 = fma (b [i + 1 ], b [i + 1 ], norm2_2 );
91+ }
92+ sum += sum1 + sum2 ;
93+ norm1 += norm1_1 + norm1_2 ;
94+ norm2 += norm2_1 + norm2_2 ;
95+ }
96+
97+ for (; i < a .length ; i ++) {
98+ sum = fma (a [i ], b [i ], sum );
99+ norm1 = fma (a [i ], a [i ], norm1 );
100+ norm2 = fma (b [i ], b [i ], norm2 );
103101 }
104102 return (float ) (sum / Math .sqrt ((double ) norm1 * (double ) norm2 ));
105103 }
106104
107105 @ Override
108106 public float squareDistance (float [] a , float [] b ) {
109- float squareSum = 0.0f ;
110- int dim = a .length ;
111- int i ;
112- for (i = 0 ; i + 8 <= dim ; i += 8 ) {
113- squareSum += squareDistanceUnrolled (a , b , i );
107+ float res = 0 ;
108+ int i = 0 ;
109+
110+ // if the array is big, unroll it
111+ if (a .length > 32 ) {
112+ float acc1 = 0 ;
113+ float acc2 = 0 ;
114+ float acc3 = 0 ;
115+ float acc4 = 0 ;
116+
117+ int upperBound = a .length & ~(4 - 1 );
118+ for (; i < upperBound ; i += 4 ) {
119+ // one
120+ float diff1 = a [i ] - b [i ];
121+ acc1 = fma (diff1 , diff1 , acc1 );
122+
123+ // two
124+ float diff2 = a [i + 1 ] - b [i + 1 ];
125+ acc2 = fma (diff2 , diff2 , acc2 );
126+
127+ // three
128+ float diff3 = a [i + 2 ] - b [i + 2 ];
129+ acc3 = fma (diff3 , diff3 , acc3 );
130+
131+ // four
132+ float diff4 = a [i + 3 ] - b [i + 3 ];
133+ acc4 = fma (diff4 , diff4 , acc4 );
134+ }
135+ res += acc1 + acc2 + acc3 + acc4 ;
114136 }
115- for (; i < dim ; i ++) {
137+
138+ for (; i < a .length ; i ++) {
116139 float diff = a [i ] - b [i ];
117- squareSum += diff * diff ;
140+ res = fma ( diff , diff , res ) ;
118141 }
119- return squareSum ;
120- }
121-
122- private static float squareDistanceUnrolled (float [] v1 , float [] v2 , int index ) {
123- float diff0 = v1 [index + 0 ] - v2 [index + 0 ];
124- float diff1 = v1 [index + 1 ] - v2 [index + 1 ];
125- float diff2 = v1 [index + 2 ] - v2 [index + 2 ];
126- float diff3 = v1 [index + 3 ] - v2 [index + 3 ];
127- float diff4 = v1 [index + 4 ] - v2 [index + 4 ];
128- float diff5 = v1 [index + 5 ] - v2 [index + 5 ];
129- float diff6 = v1 [index + 6 ] - v2 [index + 6 ];
130- float diff7 = v1 [index + 7 ] - v2 [index + 7 ];
131- return diff0 * diff0
132- + diff1 * diff1
133- + diff2 * diff2
134- + diff3 * diff3
135- + diff4 * diff4
136- + diff5 * diff5
137- + diff6 * diff6
138- + diff7 * diff7 ;
142+ return res ;
139143 }
140144
141145 @ Override
0 commit comments