Skip to content

Commit 0d407cc

Browse files
authored
Merge branch 'main' into update-esql-docs
2 parents fac5115 + 9daa870 commit 0d407cc

File tree

136 files changed

+4615
-365
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

136 files changed

+4615
-365
lines changed

qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
8787
return builder.build();
8888
}
8989

90-
static final ObjectParser<CmdLineArgs.Builder, Void> PARSER = new ObjectParser<>("cmd_line_args", true, Builder::new);
90+
static final ObjectParser<CmdLineArgs.Builder, Void> PARSER = new ObjectParser<>("cmd_line_args", false, Builder::new);
9191

9292
static {
9393
PARSER.declareStringArray(Builder::setDocVectors, DOC_VECTORS_FIELD);

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
import org.apache.lucene.codecs.KnnVectorsFormat;
1616
import org.apache.lucene.codecs.lucene101.Lucene101Codec;
1717
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
18+
import org.apache.lucene.index.DirectoryReader;
19+
import org.apache.lucene.index.IndexReader;
1820
import org.apache.lucene.index.LogByteSizeMergePolicy;
1921
import org.apache.lucene.index.LogDocMergePolicy;
2022
import org.apache.lucene.index.MergePolicy;
2123
import org.apache.lucene.index.NoMergePolicy;
2224
import org.apache.lucene.index.TieredMergePolicy;
25+
import org.apache.lucene.store.FSDirectory;
2326
import org.elasticsearch.cli.ProcessInfo;
2427
import org.elasticsearch.common.Strings;
2528
import org.elasticsearch.common.logging.LogConfigurator;
@@ -37,7 +40,9 @@
3740
import org.elasticsearch.xcontent.XContentParserConfiguration;
3841
import org.elasticsearch.xcontent.XContentType;
3942

43+
import java.io.IOException;
4044
import java.io.InputStream;
45+
import java.io.UncheckedIOException;
4146
import java.lang.management.ThreadInfo;
4247
import java.nio.file.Files;
4348
import java.nio.file.Path;
@@ -230,10 +235,9 @@ public static void main(String[] args) throws Exception {
230235
}
231236
if (cmdLineArgs.forceMerge()) {
232237
knnIndexer.forceMerge(indexResults);
233-
} else {
234-
knnIndexer.numSegments(indexResults);
235238
}
236239
}
240+
numSegments(indexPath, indexResults);
237241
if (cmdLineArgs.queryVectors() != null && cmdLineArgs.numQueries() > 0) {
238242
for (int i = 0; i < results.length; i++) {
239243
int nProbe = nProbes[i];
@@ -265,6 +269,14 @@ private static MergePolicy getMergePolicy(CmdLineArgs args) {
265269
return mergePolicy;
266270
}
267271

272+
static void numSegments(Path indexPath, KnnIndexTester.Results result) {
273+
try (FSDirectory dir = FSDirectory.open(indexPath); IndexReader reader = DirectoryReader.open(dir)) {
274+
result.numSegments = reader.leaves().size();
275+
} catch (IOException e) {
276+
throw new UncheckedIOException("Failed to get segment count for index at " + indexPath, e);
277+
}
278+
}
279+
268280
static class FormattedResults {
269281
List<Results> indexResults = new ArrayList<>();
270282
List<Results> queryResults = new ArrayList<>();

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727
import org.apache.lucene.document.KnnFloatVectorField;
2828
import org.apache.lucene.document.StoredField;
2929
import org.apache.lucene.index.ConcurrentMergeScheduler;
30-
import org.apache.lucene.index.DirectoryReader;
31-
import org.apache.lucene.index.IndexReader;
3230
import org.apache.lucene.index.IndexWriter;
3331
import org.apache.lucene.index.IndexWriterConfig;
3432
import org.apache.lucene.index.MergePolicy;
@@ -94,14 +92,6 @@ class KnnIndexer {
9492
this.mergePolicy = mergePolicy;
9593
}
9694

97-
void numSegments(KnnIndexTester.Results result) {
98-
try (FSDirectory dir = FSDirectory.open(indexPath); IndexReader reader = DirectoryReader.open(dir)) {
99-
result.numSegments = reader.leaves().size();
100-
} catch (IOException e) {
101-
throw new UncheckedIOException("Failed to get segment count for index at " + indexPath, e);
102-
}
103-
}
104-
10595
void createIndex(KnnIndexTester.Results result) throws IOException, InterruptedException, ExecutionException {
10696
IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.CREATE);
10797
iwc.setCodec(codec);
@@ -280,9 +270,11 @@ public void run() {
280270

281271
private void _run() throws IOException {
282272
while (true) {
283-
int id = numDocsIndexed.getAndIncrement();
284-
if (id >= numDocsToIndex) {
273+
int id = numDocsIndexed.get();
274+
if (id == numDocsToIndex) {
285275
break;
276+
} else if (numDocsIndexed.compareAndSet(id, id + 1) == false) {
277+
continue;
286278
}
287279

288280
Document doc = new Document();

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,29 +106,57 @@ KMeansIntermediate clusterAndSplit(final FloatVectorValues vectors, final int ta
106106
// TODO: consider adding cluster size counts to the kmeans algo
107107
// handle assignment here so we can track distance and cluster size
108108
int[] centroidVectorCount = new int[centroids.length];
109+
int effectiveCluster = -1;
109110
int effectiveK = 0;
110111
for (int assigment : assignments) {
111112
centroidVectorCount[assigment]++;
112113
// this cluster has received an assignment, its now effective, but only count it once
113114
if (centroidVectorCount[assigment] == 1) {
114115
effectiveK++;
116+
effectiveCluster = assigment;
115117
}
116118
}
117119

118120
if (effectiveK == 1) {
121+
final float[][] singleClusterCentroid = new float[1][];
122+
singleClusterCentroid[0] = centroids[effectiveCluster];
123+
kMeansIntermediate.setCentroids(singleClusterCentroid);
124+
Arrays.fill(kMeansIntermediate.assignments(), 0);
119125
return kMeansIntermediate;
120126
}
121127

128+
int removedElements = 0;
122129
for (int c = 0; c < centroidVectorCount.length; c++) {
123130
// Recurse for each cluster which is larger than targetSize
124131
// Give ourselves 30% margin for the target size
125-
if (100 * centroidVectorCount[c] > 134 * targetSize) {
126-
FloatVectorValues sample = createClusterSlice(centroidVectorCount[c], c, vectors, assignments);
127-
132+
final int count = centroidVectorCount[c];
133+
final int adjustedCentroid = c - removedElements;
134+
if (100 * count > 134 * targetSize) {
135+
final FloatVectorValues sample = createClusterSlice(count, adjustedCentroid, vectors, assignments);
128136
// TODO: consider iterative here instead of recursive
129137
// recursive call to build out the sub partitions around this centroid c
130138
// subsequently reconcile and flatten the space of all centroids and assignments into one structure we can return
131-
updateAssignmentsWithRecursiveSplit(kMeansIntermediate, c, clusterAndSplit(sample, targetSize));
139+
updateAssignmentsWithRecursiveSplit(kMeansIntermediate, adjustedCentroid, clusterAndSplit(sample, targetSize));
140+
} else if (count == 0) {
141+
// remove empty clusters
142+
final int newSize = kMeansIntermediate.centroids().length - 1;
143+
final float[][] newCentroids = new float[newSize][];
144+
System.arraycopy(kMeansIntermediate.centroids(), 0, newCentroids, 0, adjustedCentroid);
145+
System.arraycopy(
146+
kMeansIntermediate.centroids(),
147+
adjustedCentroid + 1,
148+
newCentroids,
149+
adjustedCentroid,
150+
newSize - adjustedCentroid
151+
);
152+
// we need to update the assignments to reflect the new centroid ordinals
153+
for (int i = 0; i < kMeansIntermediate.assignments().length; i++) {
154+
if (kMeansIntermediate.assignments()[i] > adjustedCentroid) {
155+
kMeansIntermediate.assignments()[i]--;
156+
}
157+
}
158+
kMeansIntermediate.setCentroids(newCentroids);
159+
removedElements++;
132160
}
133161
}
134162

server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeansTests.java

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,63 @@ private static FloatVectorValues generateData(int nSamples, int nDims, int nClus
7474
}
7575
return FloatVectorValues.fromFloats(vectors, nDims);
7676
}
77+
78+
public void testFewDifferentValues() throws IOException {
79+
int nVectors = random().nextInt(100, 1000);
80+
int targetSize = random().nextInt(4, 64);
81+
int dims = random().nextInt(2, 20);
82+
int diffValues = randomIntBetween(1, 5);
83+
float[][] values = new float[diffValues][dims];
84+
for (int i = 0; i < diffValues; i++) {
85+
for (int j = 0; j < dims; j++) {
86+
values[i][j] = random().nextFloat();
87+
}
88+
}
89+
List<float[]> vectorList = new ArrayList<>(nVectors);
90+
for (int i = 0; i < nVectors; i++) {
91+
vectorList.add(values[random().nextInt(diffValues)]);
92+
}
93+
FloatVectorValues vectors = FloatVectorValues.fromFloats(vectorList, dims);
94+
95+
HierarchicalKMeans hkmeans = new HierarchicalKMeans(
96+
dims,
97+
random().nextInt(1, 100),
98+
random().nextInt(Math.min(nVectors, 100), nVectors + 1),
99+
random().nextInt(2, 512),
100+
random().nextFloat(0.5f, 1.5f)
101+
);
102+
103+
KMeansResult result = hkmeans.cluster(vectors, targetSize);
104+
105+
float[][] centroids = result.centroids();
106+
int[] assignments = result.assignments();
107+
int[] soarAssignments = result.soarAssignments();
108+
109+
int[] counts = new int[centroids.length];
110+
for (int i = 0; i < assignments.length; i++) {
111+
counts[assignments[i]]++;
112+
}
113+
int totalCount = 0;
114+
for (int count : counts) {
115+
totalCount += count;
116+
assertTrue(count > 0);
117+
}
118+
assertEquals(nVectors, totalCount);
119+
120+
assertEquals(nVectors, assignments.length);
121+
122+
for (int assignment : assignments) {
123+
assertTrue(assignment >= 0 && assignment < centroids.length);
124+
}
125+
if (centroids.length > 1 && centroids.length < nVectors) {
126+
assertEquals(nVectors, soarAssignments.length);
127+
// verify no duplicates exist
128+
for (int i = 0; i < assignments.length; i++) {
129+
assertTrue(soarAssignments[i] >= 0 && soarAssignments[i] < centroids.length);
130+
assertNotEquals(assignments[i], soarAssignments[i]);
131+
}
132+
} else {
133+
assertEquals(0, soarAssignments.length);
134+
}
135+
}
77136
}

x-pack/plugin/esql/compute/build.gradle

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def addOccurrence(props, Occurrence) {
8686
newProps["Occurrence"] = Occurrence
8787
newProps["First"] = Occurrence == "First" ? "true" : ""
8888
newProps["Last"] = Occurrence == "Last" ? "true" : ""
89+
newProps["occurrence"] = Occurrence.toLowerCase(Locale.ROOT)
8990
return newProps
9091
}
9192

@@ -469,6 +470,27 @@ tasks.named('stringTemplates').configure {
469470
it.inputFile = stateInputFile
470471
it.outputFile = "org/elasticsearch/compute/aggregation/DoubleState.java"
471472
}
473+
474+
/*
475+
* Generates pairwise states. We generate the ones that we need at the moment,
476+
* but add more if you need more.
477+
*/
478+
File twoStateInputFile = file("src/main/java/org/elasticsearch/compute/aggregation/X-2State.java.st")
479+
[longProperties].forEach { v1 ->
480+
[intProperties, longProperties, floatProperties, doubleProperties].forEach { v2 ->
481+
{
482+
var properties = [:]
483+
v1.forEach { k, v -> properties["v1_" + k] = v}
484+
v2.forEach { k, v -> properties["v2_" + k] = v}
485+
template {
486+
it.properties = properties
487+
it.inputFile = twoStateInputFile
488+
it.outputFile = "org/elasticsearch/compute/aggregation/${v1.Type}${v2.Type}State.java"
489+
}
490+
}
491+
}
492+
}
493+
472494
File fallibleStateInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/X-FallibleState.java.st")
473495
template {
474496
it.properties = booleanProperties

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorFunctionSupplierImplementer.java

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,15 @@
2525
import java.util.stream.Collectors;
2626
import java.util.stream.Stream;
2727

28+
import javax.lang.model.element.ExecutableElement;
2829
import javax.lang.model.element.Modifier;
2930
import javax.lang.model.element.TypeElement;
3031
import javax.lang.model.util.Elements;
3132

33+
import static org.elasticsearch.compute.gen.Methods.optionalStaticMethod;
34+
import static org.elasticsearch.compute.gen.Methods.requireArgs;
35+
import static org.elasticsearch.compute.gen.Methods.requireName;
36+
import static org.elasticsearch.compute.gen.Methods.requireType;
3237
import static org.elasticsearch.compute.gen.Types.AGGREGATOR_FUNCTION_SUPPLIER;
3338
import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT;
3439
import static org.elasticsearch.compute.gen.Types.LIST_AGG_FUNC_DESC;
@@ -210,17 +215,24 @@ private MethodSpec describe() {
210215
MethodSpec.Builder builder = MethodSpec.methodBuilder("describe").returns(String.class);
211216
builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC);
212217

213-
String name = declarationType.getSimpleName().toString();
214-
name = name.replace("BytesRef", "Byte"); // The hack expects one word types so let's make BytesRef into Byte
215-
String[] parts = name.split("(?=\\p{Upper})");
216-
if (false == parts[parts.length - 1].equals("Aggregator") || parts.length < 3) {
217-
throw new IllegalArgumentException("Can't generate description for " + declarationType.getSimpleName());
218+
ExecutableElement describe = optionalStaticMethod(declarationType, requireType(STRING), requireName("describe"), requireArgs());
219+
if (describe == null) {
220+
String name = declarationType.getSimpleName().toString();
221+
name = name.replace("BytesRef", "Byte"); // The hack expects one word types so let's make BytesRef into Byte
222+
String[] parts = name.split("(?=\\p{Upper})");
223+
if (false == parts[parts.length - 1].equals("Aggregator") || parts.length < 3) {
224+
throw new IllegalArgumentException("Can't generate description for " + declarationType.getSimpleName());
225+
}
226+
227+
String operation = Arrays.stream(parts, 0, parts.length - 2)
228+
.map(s -> s.toLowerCase(Locale.ROOT))
229+
.collect(Collectors.joining("_"));
230+
String type = parts[parts.length - 2];
231+
232+
builder.addStatement("return $S", operation + " of " + type.toLowerCase(Locale.ROOT) + "s");
233+
} else {
234+
builder.addStatement("return $T.$L()", declarationType, "describe");
218235
}
219-
220-
String operation = Arrays.stream(parts, 0, parts.length - 2).map(s -> s.toLowerCase(Locale.ROOT)).collect(Collectors.joining("_"));
221-
String type = parts[parts.length - 2];
222-
223-
builder.addStatement("return $S", operation + " of " + type.toLowerCase(Locale.ROOT) + "s");
224236
return builder.build();
225237
}
226238
}

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ private TypeSpec type() {
198198
builder.addMethod(addRawInputLoop(groupIdClass, true));
199199
builder.addMethod(addIntermediateInput(groupIdClass));
200200
}
201+
builder.addMethod(maybeEnableGroupIdTracking());
201202
builder.addMethod(selectedMayContainUnseenGroups());
202203
builder.addMethod(evaluateIntermediate());
203204
builder.addMethod(evaluateFinal());
@@ -321,9 +322,11 @@ private MethodSpec prepareProcessRawInputPage() {
321322
builder.addStatement("$T $L = $L.asVector()", vectorType(p.type()), p.vectorName(), p.blockName());
322323
builder.beginControlFlow("if ($L == null)", p.vectorName());
323324
{
324-
builder.beginControlFlow("if ($L.mayHaveNulls())", p.blockName());
325-
builder.addStatement("state.enableGroupIdTracking(seenGroupIds)");
326-
builder.endControlFlow();
325+
builder.addStatement(
326+
"maybeEnableGroupIdTracking(seenGroupIds, "
327+
+ aggParams.stream().map(AggregationParameter::blockName).collect(joining(", "))
328+
+ ")"
329+
);
327330
returnAddInput(builder, false);
328331
}
329332
builder.endControlFlow();
@@ -351,6 +354,23 @@ private void returnAddInput(MethodSpec.Builder builder, boolean valuesAreVector)
351354
}
352355
}
353356

357+
private MethodSpec maybeEnableGroupIdTracking() {
358+
MethodSpec.Builder builder = MethodSpec.methodBuilder("maybeEnableGroupIdTracking");
359+
builder.addModifiers(Modifier.PRIVATE).returns(TypeName.VOID);
360+
builder.addParameter(SEEN_GROUP_IDS, "seenGroupIds");
361+
for (AggregationParameter p : aggParams) {
362+
builder.addParameter(blockType(p.type()), p.blockName());
363+
}
364+
365+
for (AggregationParameter p : aggParams) {
366+
builder.beginControlFlow("if ($L.mayHaveNulls())", p.blockName());
367+
builder.addStatement("state.enableGroupIdTracking(seenGroupIds)");
368+
builder.endControlFlow();
369+
}
370+
371+
return builder.build();
372+
}
373+
354374
/**
355375
* Generate an {@code AddInput} implementation. That's a collection path optimized for the input data.
356376
*/

0 commit comments

Comments (0)