
Commit 40e55b0

rmuir and uschindler authored
Speed up vectorutil float scalar methods, unroll properly, use fma where possible (#12737)
Co-authored-by: Uwe Schindler <[email protected]>
1 parent b8a9b0a commit 40e55b0

5 files changed: +208 -109 lines changed

lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java

Lines changed: 96 additions & 92 deletions
@@ -17,72 +17,46 @@
 
 package org.apache.lucene.internal.vectorization;
 
+import org.apache.lucene.util.Constants;
+import org.apache.lucene.util.SuppressForbidden;
+
 final class DefaultVectorUtilSupport implements VectorUtilSupport {
 
   DefaultVectorUtilSupport() {}
 
+  // the way FMA should work! if available use it, otherwise fall back to mul/add
+  @SuppressForbidden(reason = "Uses FMA only where fast and carefully contained")
+  private static float fma(float a, float b, float c) {
+    if (Constants.HAS_FAST_SCALAR_FMA) {
+      return Math.fma(a, b, c);
+    } else {
+      return a * b + c;
+    }
+  }
+
   @Override
   public float dotProduct(float[] a, float[] b) {
     float res = 0f;
-    /*
-     * If length of vector is larger than 8, we use unrolled dot product to accelerate the
-     * calculation.
-     */
-    int i;
-    for (i = 0; i < a.length % 8; i++) {
-      res += b[i] * a[i];
-    }
-    if (a.length < 8) {
-      return res;
-    }
-    for (; i + 31 < a.length; i += 32) {
-      res +=
-          b[i + 0] * a[i + 0]
-              + b[i + 1] * a[i + 1]
-              + b[i + 2] * a[i + 2]
-              + b[i + 3] * a[i + 3]
-              + b[i + 4] * a[i + 4]
-              + b[i + 5] * a[i + 5]
-              + b[i + 6] * a[i + 6]
-              + b[i + 7] * a[i + 7];
-      res +=
-          b[i + 8] * a[i + 8]
-              + b[i + 9] * a[i + 9]
-              + b[i + 10] * a[i + 10]
-              + b[i + 11] * a[i + 11]
-              + b[i + 12] * a[i + 12]
-              + b[i + 13] * a[i + 13]
-              + b[i + 14] * a[i + 14]
-              + b[i + 15] * a[i + 15];
-      res +=
-          b[i + 16] * a[i + 16]
-              + b[i + 17] * a[i + 17]
-              + b[i + 18] * a[i + 18]
-              + b[i + 19] * a[i + 19]
-              + b[i + 20] * a[i + 20]
-              + b[i + 21] * a[i + 21]
-              + b[i + 22] * a[i + 22]
-              + b[i + 23] * a[i + 23];
-      res +=
-          b[i + 24] * a[i + 24]
-              + b[i + 25] * a[i + 25]
-              + b[i + 26] * a[i + 26]
-              + b[i + 27] * a[i + 27]
-              + b[i + 28] * a[i + 28]
-              + b[i + 29] * a[i + 29]
-              + b[i + 30] * a[i + 30]
-              + b[i + 31] * a[i + 31];
+    int i = 0;
+
+    // if the array is big, unroll it
+    if (a.length > 32) {
+      float acc1 = 0;
+      float acc2 = 0;
+      float acc3 = 0;
+      float acc4 = 0;
+      int upperBound = a.length & ~(4 - 1);
+      for (; i < upperBound; i += 4) {
+        acc1 = fma(a[i], b[i], acc1);
+        acc2 = fma(a[i + 1], b[i + 1], acc2);
+        acc3 = fma(a[i + 2], b[i + 2], acc3);
+        acc4 = fma(a[i + 3], b[i + 3], acc4);
+      }
+      res += acc1 + acc2 + acc3 + acc4;
     }
-    for (; i + 7 < a.length; i += 8) {
-      res +=
-          b[i + 0] * a[i + 0]
-              + b[i + 1] * a[i + 1]
-              + b[i + 2] * a[i + 2]
-              + b[i + 3] * a[i + 3]
-              + b[i + 4] * a[i + 4]
-              + b[i + 5] * a[i + 5]
-              + b[i + 6] * a[i + 6]
-              + b[i + 7] * a[i + 7];
+
+    for (; i < a.length; i++) {
+      res = fma(a[i], b[i], res);
     }
     return res;
   }
@@ -92,50 +66,80 @@ public float cosine(float[] a, float[] b) {
     float sum = 0.0f;
     float norm1 = 0.0f;
     float norm2 = 0.0f;
-    int dim = a.length;
+    int i = 0;
 
-    for (int i = 0; i < dim; i++) {
-      float elem1 = a[i];
-      float elem2 = b[i];
-      sum += elem1 * elem2;
-      norm1 += elem1 * elem1;
-      norm2 += elem2 * elem2;
+    // if the array is big, unroll it
+    if (a.length > 32) {
+      float sum1 = 0;
+      float sum2 = 0;
+      float norm1_1 = 0;
+      float norm1_2 = 0;
+      float norm2_1 = 0;
+      float norm2_2 = 0;
+
+      int upperBound = a.length & ~(2 - 1);
+      for (; i < upperBound; i += 2) {
+        // one
+        sum1 = fma(a[i], b[i], sum1);
+        norm1_1 = fma(a[i], a[i], norm1_1);
+        norm2_1 = fma(b[i], b[i], norm2_1);
+
+        // two
+        sum2 = fma(a[i + 1], b[i + 1], sum2);
+        norm1_2 = fma(a[i + 1], a[i + 1], norm1_2);
+        norm2_2 = fma(b[i + 1], b[i + 1], norm2_2);
+      }
+      sum += sum1 + sum2;
+      norm1 += norm1_1 + norm1_2;
+      norm2 += norm2_1 + norm2_2;
+    }
+
+    for (; i < a.length; i++) {
+      sum = fma(a[i], b[i], sum);
+      norm1 = fma(a[i], a[i], norm1);
+      norm2 = fma(b[i], b[i], norm2);
     }
     return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
   }
 
   @Override
   public float squareDistance(float[] a, float[] b) {
-    float squareSum = 0.0f;
-    int dim = a.length;
-    int i;
-    for (i = 0; i + 8 <= dim; i += 8) {
-      squareSum += squareDistanceUnrolled(a, b, i);
+    float res = 0;
+    int i = 0;
+
+    // if the array is big, unroll it
+    if (a.length > 32) {
+      float acc1 = 0;
+      float acc2 = 0;
+      float acc3 = 0;
+      float acc4 = 0;
+
+      int upperBound = a.length & ~(4 - 1);
+      for (; i < upperBound; i += 4) {
+        // one
+        float diff1 = a[i] - b[i];
+        acc1 = fma(diff1, diff1, acc1);
+
+        // two
+        float diff2 = a[i + 1] - b[i + 1];
+        acc2 = fma(diff2, diff2, acc2);
+
+        // three
+        float diff3 = a[i + 2] - b[i + 2];
+        acc3 = fma(diff3, diff3, acc3);
+
+        // four
+        float diff4 = a[i + 3] - b[i + 3];
+        acc4 = fma(diff4, diff4, acc4);
+      }
+      res += acc1 + acc2 + acc3 + acc4;
     }
-    for (; i < dim; i++) {
+
+    for (; i < a.length; i++) {
       float diff = a[i] - b[i];
-      squareSum += diff * diff;
+      res = fma(diff, diff, res);
     }
-    return squareSum;
-  }
-
-  private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) {
-    float diff0 = v1[index + 0] - v2[index + 0];
-    float diff1 = v1[index + 1] - v2[index + 1];
-    float diff2 = v1[index + 2] - v2[index + 2];
-    float diff3 = v1[index + 3] - v2[index + 3];
-    float diff4 = v1[index + 4] - v2[index + 4];
-    float diff5 = v1[index + 5] - v2[index + 5];
-    float diff6 = v1[index + 6] - v2[index + 6];
-    float diff7 = v1[index + 7] - v2[index + 7];
-    return diff0 * diff0
-        + diff1 * diff1
-        + diff2 * diff2
-        + diff3 * diff3
-        + diff4 * diff4
-        + diff5 * diff5
-        + diff6 * diff6
-        + diff7 * diff7;
+    return res;
   }
 
   @Override
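
The new scalar loops above all follow the same pattern: independent accumulators break the dependency chain between consecutive fused multiply-adds (so the CPU can overlap FMA latency), and a plain tail loop handles the leftover elements. Below is a minimal standalone sketch of that pattern; the class and method names (FmaUnrollDemo, fmaOrMulAdd, dot) are illustrative, not part of the patch, and the fast-FMA flag is hardcoded for the demo instead of detected:

import java.util.Arrays;

public class FmaUnrollDemo {
  // stands in for Constants.HAS_FAST_SCALAR_FMA; assume a modern CPU for this demo
  static final boolean HAS_FAST_SCALAR_FMA = true;

  // mirrors the patch's fallback: single-rounding fma when fast, mul/add otherwise
  static float fmaOrMulAdd(float a, float b, float c) {
    return HAS_FAST_SCALAR_FMA ? Math.fma(a, b, c) : a * b + c;
  }

  // four independent accumulators hide FMA latency, as in the patched dotProduct
  static float dot(float[] a, float[] b) {
    float acc1 = 0, acc2 = 0, acc3 = 0, acc4 = 0;
    int i = 0;
    int upperBound = a.length & ~(4 - 1); // round length down to a multiple of 4
    for (; i < upperBound; i += 4) {
      acc1 = fmaOrMulAdd(a[i], b[i], acc1);
      acc2 = fmaOrMulAdd(a[i + 1], b[i + 1], acc2);
      acc3 = fmaOrMulAdd(a[i + 2], b[i + 2], acc3);
      acc4 = fmaOrMulAdd(a[i + 3], b[i + 3], acc4);
    }
    float res = acc1 + acc2 + acc3 + acc4;
    for (; i < a.length; i++) { // scalar tail for the 37 % 4 leftover below
      res = fmaOrMulAdd(a[i], b[i], res);
    }
    return res;
  }

  public static void main(String[] args) {
    float[] a = new float[37];
    float[] b = new float[37];
    Arrays.fill(a, 0.1f);
    Arrays.fill(b, 3f);
    System.out.println(dot(a, b)); // ~11.1

    // Math.fma rounds once; a * b + c rounds twice, so the last bit can differ
    float x = 0.1f, y = 0.1f, z = -0.01f;
    System.out.println(Math.fma(x, y, z) + " vs " + (x * y + z));
  }
}

The last two lines also show why the patch gates Math.fma behind a detected constant: fma rounds once where a * b + c rounds twice, and on CPUs without an FMA instruction Math.fma falls back to slow software emulation (the "BigDecimal logic" mentioned in Constants.java below).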

lucene/core/src/java/org/apache/lucene/util/Constants.java

Lines changed: 70 additions & 7 deletions
@@ -18,7 +18,6 @@
 
 import java.security.AccessController;
 import java.security.PrivilegedAction;
-import java.util.Objects;
 import java.util.logging.Logger;
 
 /** Some useful constants. */
@@ -67,12 +66,6 @@ private Constants() {} // can't construct
   /** True iff running on a 64bit JVM */
   public static final boolean JRE_IS_64BIT = is64Bit();
 
-  /** true iff we know fast FMA is supported, to deliver less error */
-  public static final boolean HAS_FAST_FMA =
-      (IS_CLIENT_VM == false)
-          && Objects.equals(OS_ARCH, "amd64")
-          && HotspotVMOptions.get("UseFMA").map(Boolean::valueOf).orElse(false);
-
   private static boolean is64Bit() {
     final String datamodel = getSysProp("sun.arch.data.model");
     if (datamodel != null) {
@@ -82,6 +75,76 @@ private static boolean is64Bit() {
     }
   }
 
+  /** true if FMA likely means a cpu instruction and not BigDecimal logic */
+  private static final boolean HAS_FMA =
+      (IS_CLIENT_VM == false) && HotspotVMOptions.get("UseFMA").map(Boolean::valueOf).orElse(false);
+
+  /** maximum supported vectorsize */
+  private static final int MAX_VECTOR_SIZE =
+      HotspotVMOptions.get("MaxVectorSize").map(Integer::valueOf).orElse(0);
+
+  /** true for an AMD cpu with SSE4a instructions */
+  private static final boolean HAS_SSE4A =
+      HotspotVMOptions.get("UseXmmI2F").map(Boolean::valueOf).orElse(false);
+
+  /** true iff we know VFMA has faster throughput than separate vmul/vadd */
+  public static final boolean HAS_FAST_VECTOR_FMA = hasFastVectorFMA();
+
+  /** true iff we know FMA has faster throughput than separate mul/add */
+  public static final boolean HAS_FAST_SCALAR_FMA = hasFastScalarFMA();
+
+  private static boolean hasFastVectorFMA() {
+    if (HAS_FMA) {
+      String value = getSysProp("lucene.useVectorFMA", "auto");
+      if ("auto".equals(value)) {
+        // newer Neoverse cores have their act together
+        // the problem is just apple silicon (this is a practical heuristic)
+        if (OS_ARCH.equals("aarch64") && MAC_OS_X == false) {
+          return true;
+        }
+        // zen cores or newer, its a wash, turn it on as it doesn't hurt
+        // starts to yield gains for vectors only at zen4+
+        if (HAS_SSE4A && MAX_VECTOR_SIZE >= 32) {
+          return true;
+        }
+        // intel has their act together
+        if (OS_ARCH.equals("amd64") && HAS_SSE4A == false) {
+          return true;
+        }
+      } else {
+        return Boolean.parseBoolean(value);
+      }
+    }
+    // everyone else is slow, until proven otherwise by benchmarks
+    return false;
+  }
+
+  private static boolean hasFastScalarFMA() {
+    if (HAS_FMA) {
+      String value = getSysProp("lucene.useScalarFMA", "auto");
+      if ("auto".equals(value)) {
+        // newer Neoverse cores have their act together
+        // the problem is just apple silicon (this is a practical heuristic)
+        if (OS_ARCH.equals("aarch64") && MAC_OS_X == false) {
+          return true;
+        }
+        // latency becomes 4 for the Zen3 (0x19h), but still a wash
+        // until the Zen4 anyway, and big drop on previous zens:
+        if (HAS_SSE4A && MAX_VECTOR_SIZE >= 64) {
+          return true;
+        }
+        // intel has their act together
+        if (OS_ARCH.equals("amd64") && HAS_SSE4A == false) {
+          return true;
+        }
+      } else {
+        return Boolean.parseBoolean(value);
+      }
+    }
+    // everyone else is slow, until proven otherwise by benchmarks
+    return false;
+  }
+
   private static String getSysProp(String property) {
     try {
       return doPrivileged(() -> System.getProperty(property));
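
The detection above reads HotSpot VM options (UseFMA, MaxVectorSize, UseXmmI2F) through Lucene's internal HotspotVMOptions helper, which is why it only works on HotSpot-based JVMs. A rough standalone equivalent of that probing, assuming a HotSpot JVM with the jdk.management module available (the class name VmOptionProbe is hypothetical):

import java.lang.management.ManagementFactory;
import com.sun.management.HotSpotDiagnosticMXBean;

public class VmOptionProbe {
  public static void main(String[] args) {
    // HotSpot-only: OpenJ9 and other JVMs may not expose this MXBean
    HotSpotDiagnosticMXBean bean =
        ManagementFactory.getPlatformMXBean(HotSpotDiagnosticMXBean.class);
    for (String name : new String[] {"UseFMA", "MaxVectorSize", "UseXmmI2F"}) {
      try {
        // getVMOption throws IllegalArgumentException for flags this JVM lacks
        System.out.println(name + " = " + bean.getVMOption(name).getValue());
      } catch (IllegalArgumentException e) {
        System.out.println(name + " is not defined on this JVM");
      }
    }
  }
}

Note the heuristic nature of the checks: UseXmmI2F is treated as a proxy for "AMD cpu with SSE4a", and the lucene.useScalarFMA / lucene.useVectorFMA system properties let users override the auto detection in either direction.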

lucene/core/src/java/org/apache/lucene/util/VectorUtil.java

Lines changed: 25 additions & 1 deletion
@@ -20,7 +20,31 @@
 import org.apache.lucene.internal.vectorization.VectorUtilSupport;
 import org.apache.lucene.internal.vectorization.VectorizationProvider;
 
-/** Utilities for computations with numeric arrays */
+/**
+ * Utilities for computations with numeric arrays, especially algebraic operations like vector dot
+ * products. This class uses SIMD vectorization if the corresponding Java module is available and
+ * enabled. To enable vectorized code, pass {@code --add-modules jdk.incubator.vector} to Java's
+ * command line.
+ *
+ * <p>It will use CPU's <a href="https://en.wikipedia.org/wiki/Fused_multiply%E2%80%93add">FMA
+ * instructions</a> if it is known to perform faster than separate multiply+add. This requires at
+ * least Hotspot C2 enabled, which is the default for OpenJDK based JVMs.
+ *
+ * <p>To explicitly disable or enable FMA usage, pass the following system properties:
+ *
+ * <ul>
+ *   <li>{@code -Dlucene.useScalarFMA=(auto|true|false)} for scalar operations
+ *   <li>{@code -Dlucene.useVectorFMA=(auto|true|false)} for vectorized operations (with vector
+ *       incubator module)
+ * </ul>
+ *
+ * <p>The default is {@code auto}, which enables this for known CPU types and JVM settings. If
+ * Hotspot C2 is disabled, FMA and vectorization are <strong>not</strong> used.
+ *
+ * <p>Vectorization and FMA is only supported for Hotspot-based JVMs; it won't work on OpenJ9-based
+ * JVMs unless they provide {@link com.sun.management.HotSpotDiagnosticMXBean}. Please also make
+ * sure that you have the {@code jdk.management} module enabled in modularized applications.
+ */
 public final class VectorUtil {
 
   private static final VectorUtilSupport IMPL =
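
For callers nothing changes: the scalar methods rewritten by this commit sit behind the public VectorUtil entry points. A small usage sketch; per the new Javadoc, run with --add-modules jdk.incubator.vector to get the vectorized paths and -Dlucene.useScalarFMA=true|false|auto to override FMA detection:

import org.apache.lucene.util.VectorUtil;

public class VectorUtilExample {
  public static void main(String[] args) {
    float[] a = {1f, 2f, 3f, 4f};
    float[] b = {4f, 3f, 2f, 1f};
    // all three dispatch to DefaultVectorUtilSupport on JVMs without the vector module
    System.out.println("dot    = " + VectorUtil.dotProduct(a, b));     // 20.0
    System.out.println("cosine = " + VectorUtil.cosine(a, b));         // ~0.6667
    System.out.println("sqDist = " + VectorUtil.squareDistance(a, b)); // 20.0
  }
}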
