Skip to content

Commit 2e6824f

Browse files
Merge branch 'main' into 2025/07/28/write-load-decider
2 parents 0824b59 + 1be098a commit 2e6824f

File tree

40 files changed

+823
-255
lines changed

40 files changed

+823
-255
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/TransposeHalfByteBenchmark.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,13 @@ public void transposeHalfByteLegacy(Blackhole bh) {
8383
bh.consume(packed);
8484
}
8585
}
86+
87+
@Benchmark
88+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
89+
public void transposeHalfBytePanama(Blackhole bh) {
90+
for (int i = 0; i < numVectors; i++) {
91+
BQSpaceUtils.transposeHalfByte(qVectors[i], packed);
92+
bh.consume(packed);
93+
}
94+
}
8695
}

docs/changelog/132675.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 132675
2+
summary: Add second max queue latency stat to `ClusterInfo`
3+
area: Allocation
4+
type: enhancement
5+
issues: []

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,4 +381,22 @@ public static void packAsBinary(int[] vector, byte[] packed) {
381381
}
382382
IMPL.packAsBinary(vector, packed);
383383
}
384+
385+
/**
386+
* The idea here is to organize the query vector bits such that the first bit
387+
* of every dimension is in the first set dimensions bits, or (dimensions/8) bytes. The second,
388+
* third, and fourth bits are in the second, third, and fourth set of dimensions bits,
389+
* respectively. This allows for direct bitwise comparisons with the stored index vectors through
390+
* summing the bitwise results with the relative required bit shifts.
391+
*
392+
* @param q the query vector, assumed to be half-byte quantized with values between 0 and 15
393+
* @param quantQueryByte the byte array to store the transposed query vector.
394+
*
395+
**/
396+
public static void transposeHalfByte(int[] q, byte[] quantQueryByte) {
397+
if (quantQueryByte.length * Byte.SIZE < 4 * q.length) {
398+
throw new IllegalArgumentException("packed array is too small: " + quantQueryByte.length * Byte.SIZE + " < " + 4 * q.length);
399+
}
400+
IMPL.transposeHalfByte(q, quantQueryByte);
401+
}
384402
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,4 +353,54 @@ public static void packAsBinaryImpl(int[] vector, byte[] packed) {
353353
}
354354
packed[index] = result;
355355
}
356+
357+
@Override
358+
public void transposeHalfByte(int[] q, byte[] quantQueryByte) {
359+
transposeHalfByteImpl(q, quantQueryByte);
360+
}
361+
362+
public static void transposeHalfByteImpl(int[] q, byte[] quantQueryByte) {
363+
int limit = q.length - 7;
364+
int i = 0;
365+
int index = 0;
366+
for (; i < limit; i += 8, index++) {
367+
assert q[i] >= 0 && q[i] <= 15;
368+
assert q[i + 1] >= 0 && q[i + 1] <= 15;
369+
assert q[i + 2] >= 0 && q[i + 2] <= 15;
370+
assert q[i + 3] >= 0 && q[i + 3] <= 15;
371+
assert q[i + 4] >= 0 && q[i + 4] <= 15;
372+
assert q[i + 5] >= 0 && q[i + 5] <= 15;
373+
assert q[i + 6] >= 0 && q[i + 6] <= 15;
374+
assert q[i + 7] >= 0 && q[i + 7] <= 15;
375+
int lowerByte = (q[i] & 1) << 7 | (q[i + 1] & 1) << 6 | (q[i + 2] & 1) << 5 | (q[i + 3] & 1) << 4 | (q[i + 4] & 1) << 3 | (q[i
376+
+ 5] & 1) << 2 | (q[i + 6] & 1) << 1 | (q[i + 7] & 1);
377+
int lowerMiddleByte = ((q[i] >> 1) & 1) << 7 | ((q[i + 1] >> 1) & 1) << 6 | ((q[i + 2] >> 1) & 1) << 5 | ((q[i + 3] >> 1) & 1)
378+
<< 4 | ((q[i + 4] >> 1) & 1) << 3 | ((q[i + 5] >> 1) & 1) << 2 | ((q[i + 6] >> 1) & 1) << 1 | ((q[i + 7] >> 1) & 1);
379+
int upperMiddleByte = ((q[i] >> 2) & 1) << 7 | ((q[i + 1] >> 2) & 1) << 6 | ((q[i + 2] >> 2) & 1) << 5 | ((q[i + 3] >> 2) & 1)
380+
<< 4 | ((q[i + 4] >> 2) & 1) << 3 | ((q[i + 5] >> 2) & 1) << 2 | ((q[i + 6] >> 2) & 1) << 1 | ((q[i + 7] >> 2) & 1);
381+
int upperByte = ((q[i] >> 3) & 1) << 7 | ((q[i + 1] >> 3) & 1) << 6 | ((q[i + 2] >> 3) & 1) << 5 | ((q[i + 3] >> 3) & 1) << 4
382+
| ((q[i + 4] >> 3) & 1) << 3 | ((q[i + 5] >> 3) & 1) << 2 | ((q[i + 6] >> 3) & 1) << 1 | ((q[i + 7] >> 3) & 1);
383+
quantQueryByte[index] = (byte) lowerByte;
384+
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
385+
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
386+
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
387+
}
388+
if (i == q.length) {
389+
return; // all done
390+
}
391+
int lowerByte = 0;
392+
int lowerMiddleByte = 0;
393+
int upperMiddleByte = 0;
394+
int upperByte = 0;
395+
for (int j = 7; i < q.length; j--, i++) {
396+
lowerByte |= (q[i] & 1) << j;
397+
lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
398+
upperMiddleByte |= ((q[i] >> 2) & 1) << j;
399+
upperByte |= ((q[i] >> 3) & 1) << j;
400+
}
401+
quantQueryByte[index] = (byte) lowerByte;
402+
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
403+
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
404+
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
405+
}
356406
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,6 @@ void soarDistanceBulk(
6565
);
6666

6767
void packAsBinary(int[] vector, byte[] packed);
68+
69+
void transposeHalfByte(int[] q, byte[] quantQueryByte);
6870
}

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.apache.lucene.util.Constants;
2323

2424
import static jdk.incubator.vector.VectorOperators.ADD;
25+
import static jdk.incubator.vector.VectorOperators.ASHR;
2526
import static jdk.incubator.vector.VectorOperators.LSHL;
2627
import static jdk.incubator.vector.VectorOperators.MAX;
2728
import static jdk.incubator.vector.VectorOperators.MIN;
@@ -1021,4 +1022,104 @@ private void packAsBinary128(int[] vector, byte[] packed) {
10211022
}
10221023
packed[index] = result;
10231024
}
1025+
1026+
@Override
1027+
public void transposeHalfByte(int[] q, byte[] quantQueryByte) {
1028+
// 128 / 32 == 4
1029+
if (q.length >= 8 && HAS_FAST_INTEGER_VECTORS) {
1030+
if (VECTOR_BITSIZE >= 256) {
1031+
transposeHalfByte256(q, quantQueryByte);
1032+
return;
1033+
} else if (VECTOR_BITSIZE == 128) {
1034+
transposeHalfByte128(q, quantQueryByte);
1035+
return;
1036+
}
1037+
}
1038+
DefaultESVectorUtilSupport.transposeHalfByteImpl(q, quantQueryByte);
1039+
}
1040+
1041+
private void transposeHalfByte256(int[] q, byte[] quantQueryByte) {
1042+
final int limit = INT_SPECIES_256.loopBound(q.length);
1043+
int i = 0;
1044+
int index = 0;
1045+
for (; i < limit; i += INT_SPECIES_256.length(), index++) {
1046+
IntVector v = IntVector.fromArray(INT_SPECIES_256, q, i);
1047+
1048+
int lowerByte = v.and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR);
1049+
int lowerMiddleByte = v.lanewise(ASHR, 1).and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR);
1050+
int upperMiddleByte = v.lanewise(ASHR, 2).and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR);
1051+
int upperByte = v.lanewise(ASHR, 3).and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR);
1052+
1053+
quantQueryByte[index] = (byte) lowerByte;
1054+
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
1055+
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
1056+
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
1057+
1058+
}
1059+
if (i == q.length) {
1060+
return; // all done
1061+
}
1062+
int lowerByte = 0;
1063+
int lowerMiddleByte = 0;
1064+
int upperMiddleByte = 0;
1065+
int upperByte = 0;
1066+
for (int j = 7; i < q.length; j--, i++) {
1067+
lowerByte |= (q[i] & 1) << j;
1068+
lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
1069+
upperMiddleByte |= ((q[i] >> 2) & 1) << j;
1070+
upperByte |= ((q[i] >> 3) & 1) << j;
1071+
}
1072+
quantQueryByte[index] = (byte) lowerByte;
1073+
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
1074+
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
1075+
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
1076+
}
1077+
1078+
private void transposeHalfByte128(int[] q, byte[] quantQueryByte) {
1079+
final int limit = INT_SPECIES_128.loopBound(q.length) - INT_SPECIES_128.length();
1080+
int i = 0;
1081+
int index = 0;
1082+
for (; i < limit; i += 2 * INT_SPECIES_128.length(), index++) {
1083+
IntVector v = IntVector.fromArray(INT_SPECIES_128, q, i);
1084+
1085+
var lowerByteHigh = v.and(1).lanewise(LSHL, HIGH_SHIFTS_128);
1086+
var lowerMiddleByteHigh = v.lanewise(ASHR, 1).and(1).lanewise(LSHL, HIGH_SHIFTS_128);
1087+
var upperMiddleByteHigh = v.lanewise(ASHR, 2).and(1).lanewise(LSHL, HIGH_SHIFTS_128);
1088+
var upperByteHigh = v.lanewise(ASHR, 3).and(1).lanewise(LSHL, HIGH_SHIFTS_128);
1089+
1090+
v = IntVector.fromArray(INT_SPECIES_128, q, i + INT_SPECIES_128.length());
1091+
var lowerByteLow = v.and(1).lanewise(LSHL, LOW_SHIFTS_128);
1092+
var lowerMiddleByteLow = v.lanewise(ASHR, 1).and(1).lanewise(LSHL, LOW_SHIFTS_128);
1093+
var upperMiddleByteLow = v.lanewise(ASHR, 2).and(1).lanewise(LSHL, LOW_SHIFTS_128);
1094+
var upperByteLow = v.lanewise(ASHR, 3).and(1).lanewise(LSHL, LOW_SHIFTS_128);
1095+
1096+
int lowerByte = lowerByteHigh.lanewise(OR, lowerByteLow).reduceLanes(OR);
1097+
int lowerMiddleByte = lowerMiddleByteHigh.lanewise(OR, lowerMiddleByteLow).reduceLanes(OR);
1098+
int upperMiddleByte = upperMiddleByteHigh.lanewise(OR, upperMiddleByteLow).reduceLanes(OR);
1099+
int upperByte = upperByteHigh.lanewise(OR, upperByteLow).reduceLanes(OR);
1100+
1101+
quantQueryByte[index] = (byte) lowerByte;
1102+
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
1103+
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
1104+
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
1105+
1106+
}
1107+
if (i == q.length) {
1108+
return; // all done
1109+
}
1110+
int lowerByte = 0;
1111+
int lowerMiddleByte = 0;
1112+
int upperMiddleByte = 0;
1113+
int upperByte = 0;
1114+
for (int j = 7; i < q.length; j--, i++) {
1115+
lowerByte |= (q[i] & 1) << j;
1116+
lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
1117+
upperMiddleByte |= ((q[i] >> 2) & 1) << j;
1118+
upperByte |= ((q[i] >> 3) & 1) << j;
1119+
}
1120+
quantQueryByte[index] = (byte) lowerByte;
1121+
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
1122+
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
1123+
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
1124+
}
10241125
}

libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,20 @@ public void testPackAsBinary() {
370370
assertArrayEquals(packedLegacy, packed);
371371
}
372372

373+
public void testTransposeHalfByte() {
374+
int dims = randomIntBetween(16, 2048);
375+
int[] toPack = new int[dims];
376+
for (int i = 0; i < dims; i++) {
377+
toPack[i] = randomInt(15);
378+
}
379+
int length = 4 * BQVectorUtils.discretize(dims, 64) / 8;
380+
byte[] packed = new byte[length];
381+
byte[] packedLegacy = new byte[length];
382+
defaultedProvider.getVectorUtilSupport().transposeHalfByte(toPack, packedLegacy);
383+
defOrPanamaProvider.getVectorUtilSupport().transposeHalfByte(toPack, packed);
384+
assertArrayEquals(packedLegacy, packed);
385+
}
386+
373387
private float[] generateRandomVector(int size) {
374388
float[] vector = new float[size];
375389
for (int i = 0; i < size; ++i) {

qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ record CmdLineArgs(
3636
KnnIndexTester.IndexType indexType,
3737
int numCandidates,
3838
int k,
39-
int[] nProbes,
39+
double[] visitPercentages,
4040
int ivfClusterSize,
4141
int overSamplingFactor,
4242
int hnswM,
@@ -63,7 +63,8 @@ record CmdLineArgs(
6363
static final ParseField INDEX_TYPE_FIELD = new ParseField("index_type");
6464
static final ParseField NUM_CANDIDATES_FIELD = new ParseField("num_candidates");
6565
static final ParseField K_FIELD = new ParseField("k");
66-
static final ParseField N_PROBE_FIELD = new ParseField("n_probe");
66+
// static final ParseField N_PROBE_FIELD = new ParseField("n_probe");
67+
static final ParseField VISIT_PERCENTAGE_FIELD = new ParseField("visit_percentage");
6768
static final ParseField IVF_CLUSTER_SIZE_FIELD = new ParseField("ivf_cluster_size");
6869
static final ParseField OVER_SAMPLING_FACTOR_FIELD = new ParseField("over_sampling_factor");
6970
static final ParseField HNSW_M_FIELD = new ParseField("hnsw_m");
@@ -97,7 +98,8 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
9798
PARSER.declareString(Builder::setIndexType, INDEX_TYPE_FIELD);
9899
PARSER.declareInt(Builder::setNumCandidates, NUM_CANDIDATES_FIELD);
99100
PARSER.declareInt(Builder::setK, K_FIELD);
100-
PARSER.declareIntArray(Builder::setNProbe, N_PROBE_FIELD);
101+
// PARSER.declareIntArray(Builder::setNProbe, N_PROBE_FIELD);
102+
PARSER.declareDoubleArray(Builder::setVisitPercentages, VISIT_PERCENTAGE_FIELD);
101103
PARSER.declareInt(Builder::setIvfClusterSize, IVF_CLUSTER_SIZE_FIELD);
102104
PARSER.declareInt(Builder::setOverSamplingFactor, OVER_SAMPLING_FACTOR_FIELD);
103105
PARSER.declareInt(Builder::setHnswM, HNSW_M_FIELD);
@@ -132,7 +134,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
132134
builder.field(INDEX_TYPE_FIELD.getPreferredName(), indexType.name().toLowerCase(Locale.ROOT));
133135
builder.field(NUM_CANDIDATES_FIELD.getPreferredName(), numCandidates);
134136
builder.field(K_FIELD.getPreferredName(), k);
135-
builder.field(N_PROBE_FIELD.getPreferredName(), nProbes);
137+
// builder.field(N_PROBE_FIELD.getPreferredName(), nProbes);
138+
builder.field(VISIT_PERCENTAGE_FIELD.getPreferredName(), visitPercentages);
136139
builder.field(IVF_CLUSTER_SIZE_FIELD.getPreferredName(), ivfClusterSize);
137140
builder.field(OVER_SAMPLING_FACTOR_FIELD.getPreferredName(), overSamplingFactor);
138141
builder.field(HNSW_M_FIELD.getPreferredName(), hnswM);
@@ -165,7 +168,7 @@ static class Builder {
165168
private KnnIndexTester.IndexType indexType = KnnIndexTester.IndexType.HNSW;
166169
private int numCandidates = 1000;
167170
private int k = 10;
168-
private int[] nProbes = new int[] { 10 };
171+
private double[] visitPercentages = new double[] { 1.0 };
169172
private int ivfClusterSize = 1000;
170173
private int overSamplingFactor = 1;
171174
private int hnswM = 16;
@@ -223,8 +226,8 @@ public Builder setK(int k) {
223226
return this;
224227
}
225228

226-
public Builder setNProbe(List<Integer> nProbes) {
227-
this.nProbes = nProbes.stream().mapToInt(Integer::intValue).toArray();
229+
public Builder setVisitPercentages(List<Double> visitPercentages) {
230+
this.visitPercentages = visitPercentages.stream().mapToDouble(Double::doubleValue).toArray();
228231
return this;
229232
}
230233

@@ -330,7 +333,7 @@ public CmdLineArgs build() {
330333
indexType,
331334
numCandidates,
332335
k,
333-
nProbes,
336+
visitPercentages,
334337
ivfClusterSize,
335338
overSamplingFactor,
336339
hnswM,

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -191,18 +191,18 @@ public static void main(String[] args) throws Exception {
191191
FormattedResults formattedResults = new FormattedResults();
192192

193193
for (CmdLineArgs cmdLineArgs : cmdLineArgsList) {
194-
int[] nProbes = cmdLineArgs.indexType().equals(IndexType.IVF) && cmdLineArgs.numQueries() > 0
195-
? cmdLineArgs.nProbes()
196-
: new int[] { 0 };
194+
double[] visitPercentages = cmdLineArgs.indexType().equals(IndexType.IVF) && cmdLineArgs.numQueries() > 0
195+
? cmdLineArgs.visitPercentages()
196+
: new double[] { 0 };
197197
String indexType = cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT);
198198
Results indexResults = new Results(
199199
cmdLineArgs.docVectors().get(0).getFileName().toString(),
200200
indexType,
201201
cmdLineArgs.numDocs(),
202202
cmdLineArgs.filterSelectivity()
203203
);
204-
Results[] results = new Results[nProbes.length];
205-
for (int i = 0; i < nProbes.length; i++) {
204+
Results[] results = new Results[visitPercentages.length];
205+
for (int i = 0; i < visitPercentages.length; i++) {
206206
results[i] = new Results(
207207
cmdLineArgs.docVectors().get(0).getFileName().toString(),
208208
indexType,
@@ -240,8 +240,7 @@ public static void main(String[] args) throws Exception {
240240
numSegments(indexPath, indexResults);
241241
if (cmdLineArgs.queryVectors() != null && cmdLineArgs.numQueries() > 0) {
242242
for (int i = 0; i < results.length; i++) {
243-
int nProbe = nProbes[i];
244-
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs, nProbe);
243+
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs, visitPercentages[i]);
245244
knnSearcher.runSearch(results[i], cmdLineArgs.earlyTermination());
246245
}
247246
}
@@ -293,7 +292,7 @@ public String toString() {
293292
String[] searchHeaders = {
294293
"index_name",
295294
"index_type",
296-
"n_probe",
295+
"visit_percentage(%)",
297296
"latency(ms)",
298297
"net_cpu_time(ms)",
299298
"avg_cpu_count",
@@ -324,7 +323,7 @@ public String toString() {
324323
queryResultsArray[i] = new String[] {
325324
queryResult.indexName,
326325
queryResult.indexType,
327-
Integer.toString(queryResult.nProbe),
326+
String.format(Locale.ROOT, "%.2f", queryResult.visitPercentage),
328327
String.format(Locale.ROOT, "%.2f", queryResult.avgLatency),
329328
String.format(Locale.ROOT, "%.2f", queryResult.netCpuTimeMS),
330329
String.format(Locale.ROOT, "%.2f", queryResult.avgCpuCount),
@@ -400,7 +399,7 @@ static class Results {
400399
long indexTimeMS;
401400
long forceMergeTimeMS;
402401
int numSegments;
403-
int nProbe;
402+
double visitPercentage;
404403
double avgLatency;
405404
double qps;
406405
double avgRecall;

0 commit comments

Comments
 (0)